Commit c6ac5db

[Core] Use zeromq to put request output tokens back to the api server

s5u13b authored Sep 10, 2024
1 parent 5b9cfe5 · commit c6ac5db
Showing 15 changed files with 561 additions and 80 deletions.
81 changes: 57 additions & 24 deletions llumnix/backends/vllm/llm_engine.py
@@ -15,8 +15,8 @@
from typing import Any, List, Optional, Dict, Union, Iterable, Tuple
from collections import defaultdict
import threading
import asyncio
import ray
from ray.util.queue import Queue as RayQueue
from ray.util.placement_group import PlacementGroup

from vllm.engine.llm_engine import LLMEngine
@@ -32,20 +32,61 @@
from llumnix.backends.backend_interface import BackendInterface
from llumnix.backends.vllm.scheduler import SchedulerLlumnix
from llumnix.backends.vllm.sequence import SequenceGroupLlumnix
from llumnix.backends.vllm.utils import detect_unsupported_feature
from llumnix.backends.profiling import LatencyMemData
from llumnix.server_info import ServerInfo
from llumnix.internal_config import MigrationConfig

from llumnix.rpc.queue_client import QueueClient

logger = init_logger(__name__)


class AsyncPutQueueThread(threading.Thread):
def __init__(self, instance_id):
super().__init__()
self.instance_id = instance_id
self.request_output_queue_client = QueueClient()
self.engine_actor_handle = None
self.loop = asyncio.new_event_loop()
self.daemon = True

def run(self):
asyncio.set_event_loop(self.loop)
self.loop.run_forever()

async def _put_nowait_batch_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
server_info_dict: Dict[str, ServerInfo]) -> None:
if self.engine_actor_handle is None:
self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix")
tasks = []
for server_id, req_outputs in server_request_outputs.items():
server_info = server_info_dict[server_id]
tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait_batch(req_outputs, server_info)))
rets = await asyncio.gather(*tasks, return_exceptions=True)
for idx, ret in enumerate(rets):
if isinstance(ret, TimeoutError):
server_id = list(server_request_outputs.keys())[idx]
logger.info("Server {} is dead".format(server_id))
req_outputs = list(server_request_outputs.values())[idx]
request_ids = [req_output.request_id for req_output in req_outputs]
self.engine_actor_handle.abort_request.remote(request_ids)

def put_nowait_batch_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
server_info_dict: Dict[str, ServerInfo]) -> None:
asyncio.run_coroutine_threadsafe(self._put_nowait_batch_to_servers(server_request_outputs, server_info_dict),
self.loop)
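
The AsyncPutQueueThread above uses a standard bridge pattern: a daemon thread owns a private asyncio event loop, and the synchronous engine code hands coroutines to that loop with asyncio.run_coroutine_threadsafe, so queue sends never block the engine's step loop. A minimal, self-contained sketch of the pattern (names are illustrative, not from llumnix):

```python
import asyncio
import concurrent.futures
import threading

class BackgroundLoopThread(threading.Thread):
    """Daemon thread that runs its own asyncio event loop forever."""

    def __init__(self) -> None:
        super().__init__(daemon=True)
        self.loop = asyncio.new_event_loop()

    def run(self) -> None:
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def submit(self, coro) -> concurrent.futures.Future:
        # Thread-safe handoff from synchronous code into the loop.
        return asyncio.run_coroutine_threadsafe(coro, self.loop)

async def _send(payload: str) -> str:
    await asyncio.sleep(0.01)  # stand-in for a network send
    return "sent " + payload

worker = BackgroundLoopThread()
worker.start()
future = worker.submit(_send("hello"))
print(future.result(timeout=1.0))  # the caller blocks only if it asks for the result
```

Fire-and-forget callers, like put_nowait_batch_to_servers above, can simply drop the returned future.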


class LLMEngineLlumnix(LLMEngine):
def __init__(self, instance_id: str, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.instance_id = instance_id
self.step_counter = Counter()
self.instance_info = None
# TODO(s5u13b): Reduce the overhead.
self.async_put_queue_thread = AsyncPutQueueThread(instance_id)
self.async_put_queue_thread.start()

# pylint: disable=W0221
@classmethod
@@ -61,7 +102,6 @@ def from_engine_args(
) -> "LLMEngineLlumnix":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
detect_unsupported_feature(engine_args)
engine_config = engine_args.create_engine_config()
engine_config.parallel_config.placement_group = placement_group
# Initialize the cluster and specify the executor class.
@@ -97,7 +137,7 @@ def _process_model_outputs(
) -> Tuple[List[RequestOutput], List[ServerInfo]]:
# ensure scheduled_seq_groups match the output
with self.scheduler.scheduler_lock:
server_info_list = []
server_infos = []
if output:
new_output = []
new_scheduled_seq_groups = []
@@ -108,18 +148,18 @@
new_scheduled_seq_groups.append(scheduled_seq_group)
new_seq_group_metadata_list.append(seq_group_meta)
new_output.append(seq_group_output)
server_info_list.append(seq_group.server_info)
server_infos.append(seq_group.server_info)
scheduled_seq_groups = new_scheduled_seq_groups
output[0].outputs = new_output
seq_group_metadata_list = new_seq_group_metadata_list
for ignored_seq_group in ignored_seq_groups:
server_info_list.append(ignored_seq_group.server_info)
server_infos.append(ignored_seq_group.server_info)
request_outputs = super()._process_model_outputs(output, scheduled_seq_groups, ignored_seq_groups, seq_group_metadata_list)
# TODO(ZeldaHuang) Use LlumnixRequestOutput to store llumnix output args.
return request_outputs, server_info_list
# TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args.
return request_outputs, server_infos

def step(self) -> None:
output_list, server_info_list = super().step()
request_outputs, server_infos = super().step()

instance_info: InstanceInfo = self.instance_info
instance_info.instance_id = self.instance_id
@@ -135,7 +175,8 @@ def step(self) -> None:
tot_blocks = set(tot_blocks)
instance_info.num_blocks_last_running_request = len(tot_blocks)

self._put_request_output_to_server(output_list, server_info_list)
if request_outputs:
self._put_request_outputs_to_server(request_outputs, server_infos)
self.instance_info = instance_info

def update_instance_info(self, instance_info: InstanceInfo) -> None:
@@ -155,23 +196,16 @@ def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs)
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _put_request_output_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
def _put_request_outputs_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
server_request_outputs = defaultdict(list)
server_queue: Dict[str, RayQueue] = {}
server_info_dict = {}
# Reorganize data in order to put request outputs to the queue in batches at one time.
for request_output, server_info in zip(request_outputs, server_infos):
server_id = server_info.server_id
request_output_queue = server_info.request_output_queue
server_request_outputs[server_id].append(request_output)
if server_id not in server_queue:
server_queue[server_id] = request_output_queue
for server_id, req_outputs in server_request_outputs.items():
try:
server_queue[server_id].actor.put_nowait_batch.remote(req_outputs)
except ray.exceptions.RayActorError:
logger.info("Server {} is dead".format(server_id))
request_ids = [req_output.request_id for req_output in req_outputs]
self.abort_request(request_ids)
if server_id not in server_info_dict:
server_info_dict[server_id] = server_info
self.async_put_queue_thread.put_nowait_batch_to_servers(server_request_outputs, server_info_dict)
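
The QueueClient that AsyncPutQueueThread sends through is not part of the hunks shown here; per the commit title it replaces the Ray queue actor with a ZeroMQ transport. A rough sketch of what such a client could look like with pyzmq's asyncio bindings, assuming a PUSH socket and pickle framing (hypothetical; the real llumnix.rpc.queue_client.QueueClient may differ in socket type, serialization, and timeout handling):

```python
import pickle
from typing import Any, List

import zmq
import zmq.asyncio

class SketchQueueClient:
    """Hypothetical zmq client: pickle a batch and PUSH it to the server."""

    def __init__(self) -> None:
        self.context = zmq.asyncio.Context()

    async def put_nowait_batch(self, items: List[Any], server_info) -> None:
        # The tcp transport and address fields are assumptions based on
        # ServerInfo's request_output_queue_ip/port attributes in this diff.
        addr = "tcp://{}:{}".format(server_info.request_output_queue_ip,
                                    server_info.request_output_queue_port)
        socket = self.context.socket(zmq.PUSH)
        try:
            socket.connect(addr)
            await socket.send(pickle.dumps(items))
        finally:
            socket.close()
```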

class BackendVLLM(BackendInterface):
def __init__(
@@ -187,7 +221,6 @@ def __init__(
instance_id=instance_id,
placement_group=placement_group,
node_id=node_id)
# multi-instance args
self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
self.engine.output_processor.scheduler = self.engine.scheduler
1 change: 0 additions & 1 deletion llumnix/backends/vllm/simulator.py
@@ -57,7 +57,6 @@ def __init__(
latency_mem=latency_mem, engine_args=engine_args)
self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
self.engine.output_processor.scheduler = self.engine.scheduler
# multi-instance args
self.migration_config = migration_config
self.instance_id = instance_id
self.step_counter = Counter()
2 changes: 2 additions & 0 deletions llumnix/config/default.py
@@ -26,6 +26,8 @@
_C.SERVER.HOST = "localhost"
# Port number for the server
_C.SERVER.PORT = 8000
# Port number for the request output queue
_C.SERVER.REQUEST_OUTPUT_QUEUE_PORT = 1234
# Path to SSL key file for secure connections
_C.SERVER.SSL_KEYFILE = None
# Path to SSL certificate file for secure connections
23 changes: 10 additions & 13 deletions llumnix/entrypoints/llumnix_utils.py
@@ -19,15 +19,15 @@
import asyncio
import ray

from ray.util.queue import Queue as RayQueue
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

from llumnix.llm_engine_manager import LLMEngineManager, MANAGER_ACTOR_NAME
from llumnix.llumlet.llumlet import Llumlet
from llumnix.backends.backend_interface import BackendType
from llumnix.logger import init_logger
from llumnix.utils import random_uuid
from llumnix.arg_utils import EngineManagerArgs
from llumnix.rpc.utils import get_open_zmq_ipc_path
from llumnix.server_info import ServerInfo
from llumnix.rpc.queue_server import QueueServer


logger = init_logger(__name__)
@@ -171,18 +171,17 @@ def init_llumlets(engine_manager_args: EngineManagerArgs,
llumlets.append(llumlet)
return instance_ids, llumlets

def init_request_output_queue() -> RayQueue:
# request_output_queue should be placed in the same node as the api server.
request_output_queue = RayQueue(actor_options={
"scheduling_strategy": NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(),
soft=False,)
})
def init_request_output_queue(server_info: ServerInfo) -> QueueServer:
rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port)
request_output_queue = QueueServer(rpc_path)
return request_output_queue
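
The QueueServer side is likewise outside this diff. Judging from how api_server.py uses it below (run_server_loop, get, cleanup), a minimal counterpart to the client sketch could look like this (hypothetical; the real llumnix.rpc.queue_server.QueueServer may differ):

```python
import asyncio
import pickle

import zmq
import zmq.asyncio

class SketchQueueServer:
    """Hypothetical zmq server: bind a PULL socket, unpickle incoming
    batches, and hand items out one at a time via an asyncio.Queue."""

    def __init__(self, rpc_path: str) -> None:
        self.context = zmq.asyncio.Context()
        self.socket = self.context.socket(zmq.PULL)
        self.socket.bind(rpc_path)
        self.queue: asyncio.Queue = asyncio.Queue()

    async def run_server_loop(self) -> None:
        while True:
            for item in pickle.loads(await self.socket.recv()):
                self.queue.put_nowait(item)

    async def get(self):
        return await self.queue.get()

    def cleanup(self) -> None:
        self.socket.close()
        self.context.term()
```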

def init_llumnix_components(engine_manager_args: EngineManagerArgs,
engine_args,
node_id: str) -> Tuple[LLMEngineManager, List[Llumlet], RayQueue]:
node_id: str,
server_info: ServerInfo) -> Tuple[LLMEngineManager, List[Llumlet], QueueServer]:
request_output_queue = init_request_output_queue(server_info)

engine_manager = init_manager(engine_manager_args)
if engine_manager_args.disable_init_instance_by_manager:
instance_ids, llumlets = init_llumlets(engine_manager_args, engine_args, node_id)
@@ -212,6 +211,4 @@ def init_llumnix_components(engine_manager_args: EngineManagerArgs,
logger.info("Init Llumnix components done, {} instances are ready, instance_ids: {}."
.format(len(available_instance_ids), available_instance_ids))

request_output_queue = init_request_output_queue()

return engine_manager, available_instance_ids, available_llumlets, request_output_queue
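
ServerInfo now carries the address of the server's output queue instead of a queue handle. The call sites in this diff (init_request_output_queue above, and ServerInfo(server_id, ip, args.request_output_queue_port) in api_server.py below) imply roughly this shape; a sketch only, since the real llumnix.server_info.ServerInfo may hold more fields:

```python
from dataclasses import dataclass

@dataclass
class ServerInfoSketch:
    server_id: str
    request_output_queue_ip: str
    request_output_queue_port: int
```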
39 changes: 22 additions & 17 deletions llumnix/entrypoints/vllm/api_server.py
@@ -28,11 +28,13 @@

from llumnix.arg_utils import EngineManagerArgs
from llumnix.server_info import ServerInfo
from llumnix.entrypoints.llumnix_utils import (launch_ray_cluster, connect_to_ray_cluster,
is_gpu_available, init_llumnix_components)
from llumnix.entrypoints.llumnix_utils import (get_ip_address,
launch_ray_cluster, connect_to_ray_cluster,
is_gpu_available, init_llumnix_components)
from llumnix.logger import init_logger
from llumnix.utils import random_uuid
from llumnix.backends.vllm.utils import check_engine_args
from llumnix.rpc.queue_server import QueueServer
from llumnix.config import get_llumnix_config, LlumnixConfig

logger = init_logger("llumnix.api_server")
@@ -41,8 +43,8 @@
instances = {}
instance_num_requests: Dict[str, int] = {}
# request_output_queue could be None if initialized in lifespan.
request_output_queue = None
server_id = None
request_output_queue: QueueServer = None
server_info = None
TIMEOUT_KEEP_ALIVE = 5 # seconds.
request_streams: Dict[str, AsyncStream] = {}
log_requests = None
@@ -53,23 +55,23 @@

async def _background_process_outputs():
while True:
qsize = await request_output_queue.actor.qsize.remote()
request_outputs = await request_output_queue.actor.get_nowait_batch.remote(qsize)
for request_output in request_outputs:
request_id = request_output.request_id
# A request could be dispatched twice when the manager is dead; the first one frees its request_streams entry when finished.
if request_id not in request_streams:
continue
request_streams[request_id].put(request_output)
if request_output.finished:
request_streams[request_id].finish()
del request_streams[request_id]
request_output = await request_output_queue.get()
request_id = request_output.request_id
# A request could be dispatched twice when the manager is dead; the first one frees its request_streams entry when finished.
if request_id not in request_streams:
continue
request_streams[request_id].put(request_output)
if request_output.finished:
request_streams[request_id].finish()
del request_streams[request_id]
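
request_streams maps request ids to per-request async streams: the loop above pushes each RequestOutput into its stream and closes the stream once the output is marked finished. AsyncStream itself is not shown in this diff, but the put/finish usage matches the vLLM-style stream pattern, roughly (illustrative sketch, not llumnix's actual class):

```python
import asyncio

class MinimalAsyncStream:
    """Per-request stream: a producer put()s outputs, a consumer
    async-iterates until finish() injects the sentinel."""

    _FINISHED = object()

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item) -> None:
        self._queue.put_nowait(item)

    def finish(self) -> None:
        self._queue.put_nowait(self._FINISHED)

    def __aiter__(self):
        return self

    async def __anext__(self):
        item = await self._queue.get()
        if item is self._FINISHED:
            raise StopAsyncIteration
        return item
```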

# pylint: disable=unused-argument
@asynccontextmanager
async def lifespan(fastapi_app: FastAPI):
asyncio.create_task(request_output_queue.run_server_loop())
asyncio.create_task(_background_process_outputs())
yield
request_output_queue.cleanup()

app = FastAPI(lifespan=lifespan)

@@ -79,7 +81,6 @@ async def manager_generate(prompt, sampling_params, request_id) -> AsyncStream:
results_generator = AsyncStream(request_id)
request_streams[request_id] = results_generator
# This request's outputs will be put into the request_output_queue of this api server no matter which instance it runs on.
server_info = ServerInfo(server_id, request_output_queue)
# If manager is unavailable, request will be directly added to the llumlet held by api server.
global manager_available
try:
@@ -241,7 +242,9 @@ async def is_ready():
action='store_true',
default=None,
help='if launch ray cluster in api server')
parser.add_argument("--request-output-queue-port", type=int, default=None)
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")

parser = EngineManagerArgs.add_cli_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
@@ -265,9 +268,11 @@
if is_gpu_available():
# Launch the Llumnix components on the current node.
server_id = random_uuid()
ip = get_ip_address()
server_info = ServerInfo(server_id, ip, args.request_output_queue_port)
node_id = ray.get_runtime_context().get_node_id()
engine_manager, instance_ids, llumlets, request_output_queue = \
init_llumnix_components(engine_manager_args, engine_args, node_id)
init_llumnix_components(engine_manager_args, engine_args, node_id, server_info)

for idx, ins_id in enumerate(instance_ids):
instances[ins_id] = llumlets[idx]
(diffs for the remaining 10 changed files not shown)