[Core] Use zeromq to put request output tokens back to the api server #28

Merged (40 commits, merged Sep 10, 2024)

Commits
3c92ed6  refactor (ZeldaHuang, Aug 28, 2024)
59f3a63  update (ZeldaHuang, Aug 30, 2024)
41fe8b2  fix (ZeldaHuang, Aug 30, 2024)
86050fc  remove todo (ZeldaHuang, Sep 2, 2024)
036cd3d  update (ZeldaHuang, Sep 2, 2024)
69284ed  add todo (ZeldaHuang, Sep 2, 2024)
9a7d551  update (ZeldaHuang, Sep 2, 2024)
cb344fc  First version of zeromq queue (s5u13b, Sep 3, 2024)
5e2af51  Merge branch 'main' into misc (s5u13b, Sep 3, 2024)
7bd7896  Fix errors (s5u13b, Sep 3, 2024)
de5f30f  Fix await (s5u13b, Sep 3, 2024)
193c966  Fixing request leaking bug (s5u13b, Sep 3, 2024)
b2608ea  Fix request output leaking bug (s5u13b, Sep 4, 2024)
685dcc8  pylint (s5u13b, Sep 4, 2024)
d5345c9  Fix (s5u13b, Sep 4, 2024)
725c9da  Fix pytest of engine and scheduler (s5u13b, Sep 4, 2024)
071f27c  Merge branch 'fix-request-output-leaking' into zeromq (s5u13b, Sep 4, 2024)
8ac92cb  Minors (s5u13b, Sep 4, 2024)
94b322b  Fix (s5u13b, Sep 4, 2024)
45d623a  Testing (s5u13b, Sep 4, 2024)
bb7be5a  Testing (s5u13b, Sep 5, 2024)
2814da0  Clean codes (s5u13b, Sep 5, 2024)
2bab685  pylint (s5u13b, Sep 5, 2024)
081e9b6  Add and fix unittest of queue (s5u13b, Sep 5, 2024)
6f4c133  Merge branch 'main' into zeromq (s5u13b, Sep 5, 2024)
4ba5d02  Merge branch 'main' into zeromq (s5u13b, Sep 5, 2024)
b346046  Merge branch 'zeromq' of https://github.com/AlibabaPAI/llumnix into z… (s5u13b, Sep 5, 2024)
7a65902  pylint (s5u13b, Sep 5, 2024)
7e83900  Merge branch 'main' into zeromq (s5u13b, Sep 5, 2024)
937de14  pylint (s5u13b, Sep 5, 2024)
ce82ba8  pylint (s5u13b, Sep 5, 2024)
7fbbead  pylint (s5u13b, Sep 5, 2024)
cd4f5b0  Rename AsyncActor (s5u13b, Sep 5, 2024)
dfdfb22  Fix default manager args (s5u13b, Sep 5, 2024)
628540c  Use thread to asynchronous put request outputs to servers (s5u13b, Sep 9, 2024)
85ef719  Add TODO (s5u13b, Sep 9, 2024)
b095815  Minors (s5u13b, Sep 9, 2024)
2d61672  Merge branch 'main' into zeromq (s5u13b, Sep 9, 2024)
5a4308b  Remove annotation (s5u13b, Sep 9, 2024)
f95c042  Fix request output queue arg (s5u13b, Sep 10, 2024)
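
At a high level, the PR swaps the Ray-backed request output queue (ray.util.queue.Queue) for a ZeroMQ client/server pair: each engine instance owns a QueueClient, and each API server runs a QueueServer bound to the IP and port carried in its ServerInfo. The QueueClient/QueueServer implementations live under llumnix/rpc/ and are not part of the hunks shown below, so the following is only a minimal sketch of the kind of asyncio-based ZeroMQ PUSH/PULL pair they presumably wrap; the class names, the pickle wire format, and the put_nowait_batch signature here are assumptions, not the merged code.

```python
# Minimal sketch, NOT the merged implementation: an asyncio ZeroMQ queue pair
# in the spirit of llumnix's QueueServer/QueueClient. The real classes add
# batching limits, timeouts, and dead-server handling.
import asyncio
import pickle

import zmq
import zmq.asyncio


class SketchQueueServer:
    def __init__(self, rpc_path: str):
        self.context = zmq.asyncio.Context()
        self.socket = self.context.socket(zmq.PULL)
        self.socket.bind(rpc_path)  # e.g. "tcp://10.0.0.1:1234" (scheme assumed)
        self.queue: asyncio.Queue = asyncio.Queue()

    async def run_server_loop(self):
        # Drain the socket into an in-process queue that consumers await on.
        while True:
            payload = await self.socket.recv()
            for item in pickle.loads(payload):
                self.queue.put_nowait(item)

    async def get(self):
        return await self.queue.get()

    def cleanup(self):
        self.socket.close()
        self.context.term()


class SketchQueueClient:
    def __init__(self):
        self.context = zmq.asyncio.Context()
        self.sockets = {}  # rpc_path -> connected PUSH socket

    async def put_nowait_batch(self, items, rpc_path: str):
        # The real put_nowait_batch takes a ServerInfo; a raw endpoint string
        # is used here to keep the sketch self-contained.
        if rpc_path not in self.sockets:
            socket = self.context.socket(zmq.PUSH)
            socket.connect(rpc_path)
            self.sockets[rpc_path] = socket
        await self.sockets[rpc_path].send(pickle.dumps(items))
```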
81 changes: 57 additions & 24 deletions llumnix/backends/vllm/llm_engine.py
```diff
@@ -15,8 +15,8 @@
 from typing import Any, List, Optional, Dict, Union, Iterable, Tuple
 from collections import defaultdict
 import threading
+import asyncio
 import ray
-from ray.util.queue import Queue as RayQueue
 from ray.util.placement_group import PlacementGroup
 
 from vllm.engine.llm_engine import LLMEngine
@@ -32,20 +32,61 @@
 from llumnix.backends.backend_interface import BackendInterface
 from llumnix.backends.vllm.scheduler import SchedulerLlumnix
 from llumnix.backends.vllm.sequence import SequenceGroupLlumnix
-from llumnix.backends.vllm.utils import detect_unsupported_feature
 from llumnix.backends.profiling import LatencyMemData
 from llumnix.server_info import ServerInfo
 from llumnix.internal_config import MigrationConfig
 
+from llumnix.rpc.queue_client import QueueClient
+
 logger = init_logger(__name__)
 
 
+class AsyncPutQueueThread(threading.Thread):
+    def __init__(self, instance_id):
+        super().__init__()
+        self.instance_id = instance_id
+        self.request_output_queue_client = QueueClient()
+        self.engine_actor_handle = None
+        self.loop = asyncio.new_event_loop()
+        self.daemon = True
+
+    def run(self):
+        asyncio.set_event_loop(self.loop)
+        self.loop.run_forever()
+
+    async def _put_nowait_batch_to_servers(self,
+                                           server_request_outputs: Dict[str, List[RequestOutput]],
+                                           server_info_dict: Dict[str, ServerInfo]) -> None:
+        if self.engine_actor_handle is None:
+            self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix")
+        tasks = []
+        for server_id, req_outputs in server_request_outputs.items():
+            server_info = server_info_dict[server_id]
+            tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait_batch(req_outputs, server_info)))
+        rets = await asyncio.gather(*tasks, return_exceptions=True)
+        for idx, ret in enumerate(rets):
+            if isinstance(ret, TimeoutError):
+                server_id = list(server_request_outputs.keys())[idx]
+                logger.info("Server {} is dead".format(server_id))
+                req_outputs = list(server_request_outputs.values())[idx]
+                request_ids = [req_output.request_id for req_output in req_outputs]
+                self.engine_actor_handle.abort_request.remote(request_ids)
+
+    def put_nowait_batch_to_servers(self,
+                                    server_request_outputs: Dict[str, List[RequestOutput]],
+                                    server_info_dict: Dict[str, ServerInfo]) -> None:
+        asyncio.run_coroutine_threadsafe(self._put_nowait_batch_to_servers(server_request_outputs, server_info_dict),
+                                         self.loop)
+
+
 class LLMEngineLlumnix(LLMEngine):
     def __init__(self, instance_id: str, *arg, **kwargs) -> None:
         super().__init__(*arg, **kwargs)
         self.instance_id = instance_id
         self.step_counter = Counter()
         self.instance_info = None
+        # TODO(s5u13b): Reduce the overhead.
+        self.async_put_queue_thread = AsyncPutQueueThread(instance_id)
+        self.async_put_queue_thread.start()
 
     # pylint: disable=W0221
     @classmethod
@@ -61,7 +102,6 @@ def from_engine_args(
     ) -> "LLMEngineLlumnix":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
-        detect_unsupported_feature(engine_args)
         engine_config = engine_args.create_engine_config()
         engine_config.parallel_config.placement_group = placement_group
         # Initialize the cluster and specify the executor class.
@@ -97,7 +137,7 @@ def _process_model_outputs(
     ) -> Tuple[List[RequestOutput], List[ServerInfo]]:
         # ensure scheduled_seq_groups matching output
         with self.scheduler.scheduler_lock:
-            server_info_list = []
+            server_infos = []
             if output:
                 new_output = []
                 new_scheduled_seq_groups = []
@@ -108,18 +148,18 @@
                         new_scheduled_seq_groups.append(scheduled_seq_group)
                         new_seq_group_metadata_list.append(seq_group_meta)
                         new_output.append(seq_group_output)
-                        server_info_list.append(seq_group.server_info)
+                        server_infos.append(seq_group.server_info)
                 scheduled_seq_groups = new_scheduled_seq_groups
                 output[0].outputs = new_output
                 seq_group_metadata_list = new_seq_group_metadata_list
             for ignored_seq_group in ignored_seq_groups:
-                server_info_list.append(ignored_seq_group.server_info)
+                server_infos.append(ignored_seq_group.server_info)
             request_outputs = super()._process_model_outputs(output, scheduled_seq_groups, ignored_seq_groups, seq_group_metadata_list)
-            # TODO(ZeldaHuang) Use LlumnixRequestOutput to store llumnix output args.
-            return request_outputs, server_info_list
+            # TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args.
+            return request_outputs, server_infos
 
     def step(self) -> None:
-        output_list, server_info_list = super().step()
+        request_outputs, server_infos = super().step()
 
         instance_info: InstanceInfo = self.instance_info
         instance_info.instance_id = self.instance_id
@@ -135,7 +175,8 @@ def step(self) -> None:
         tot_blocks = set(tot_blocks)
         instance_info.num_blocks_last_running_request = len(tot_blocks)
 
-        self._put_request_output_to_server(output_list, server_info_list)
+        if request_outputs:
+            self._put_request_outputs_to_server(request_outputs, server_infos)
         self.instance_info = instance_info
 
     def update_instance_info(self, instance_info: InstanceInfo) -> None:
@@ -155,23 +196,16 @@ def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs)
                                          seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
         self.scheduler.scheduler_lock.release()
 
-    def _put_request_output_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
+    def _put_request_outputs_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
         server_request_outputs = defaultdict(list)
-        server_queue: Dict[str, RayQueue] = {}
+        server_info_dict = {}
+        # Reorganize data in order to put request outputs to the queue in batch at one time.
         for request_output, server_info in zip(request_outputs, server_infos):
            server_id = server_info.server_id
-            request_output_queue = server_info.request_output_queue
            server_request_outputs[server_id].append(request_output)
-            if server_id not in server_queue:
-                server_queue[server_id] = request_output_queue
-        for server_id, req_outputs in server_request_outputs.items():
-            try:
-                server_queue[server_id].actor.put_nowait_batch.remote(req_outputs)
-            except ray.exceptions.RayActorError:
-                logger.info("Server {} is dead".format(server_id))
-                request_ids = [req_output.request_id for req_output in req_outputs]
-                self.abort_request(request_ids)
+            if server_id not in server_info_dict:
+                server_info_dict[server_id] = server_info
+        self.async_put_queue_thread.put_nowait_batch_to_servers(server_request_outputs, server_info_dict)
 
 class BackendVLLM(BackendInterface):
     def __init__(
@@ -187,7 +221,6 @@ def __init__(
             instance_id=instance_id,
             placement_group=placement_group,
            node_id=node_id)
-        # multi-instance args
         self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
         self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
         self.engine.output_processor.scheduler = self.engine.scheduler
```
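
Two details of AsyncPutQueueThread are worth calling out. First, dead-server detection moves from catching ray.exceptions.RayActorError to inspecting the TimeoutError values returned by asyncio.gather(..., return_exceptions=True), after which the affected requests are aborted through the engine's own Ray actor handle. Second, the class uses a standard asyncio pattern: a daemon thread owns a private event loop, and the engine's synchronous step() hands coroutines to it via asyncio.run_coroutine_threadsafe, so queue I/O never blocks inference. A self-contained sketch of just that pattern (the names here are illustrative, not from the PR):

```python
# Standalone sketch of the daemon-thread-plus-private-event-loop pattern used
# by AsyncPutQueueThread above. Only the pattern mirrors the PR; the payload
# and work are placeholders.
import asyncio
import threading
import time


class AsyncWorkerThread(threading.Thread):
    def __init__(self):
        super().__init__()
        self.loop = asyncio.new_event_loop()  # loop owned by this thread only
        self.daemon = True                    # don't block process exit

    def run(self):
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    async def _do_work(self, payload: str) -> None:
        await asyncio.sleep(0.1)              # stands in for zmq socket I/O
        print("sent:", payload)

    def submit(self, payload: str) -> None:
        # Safe to call from any thread; schedules the coroutine on self.loop
        # and returns immediately, like put_nowait_batch_to_servers above.
        asyncio.run_coroutine_threadsafe(self._do_work(payload), self.loop)


worker = AsyncWorkerThread()
worker.start()
worker.submit("request output batch")
time.sleep(0.5)  # give the worker loop time to run before the demo exits
```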
1 change: 0 additions & 1 deletion llumnix/backends/vllm/simulator.py
```diff
@@ -57,7 +57,6 @@ def __init__(
             latency_mem=latency_mem, engine_args=engine_args)
         self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
         self.engine.output_processor.scheduler = self.engine.scheduler
-        # multi-instance args
         self.migration_config = migration_config
         self.instance_id = instance_id
         self.step_counter = Counter()
```
23 changes: 10 additions & 13 deletions llumnix/entrypoints/llumnix_utils.py
```diff
@@ -19,15 +19,15 @@
 import asyncio
 import ray
 
-from ray.util.queue import Queue as RayQueue
-from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
-
 from llumnix.llm_engine_manager import LLMEngineManager, MANAGER_ACTOR_NAME
 from llumnix.llumlet.llumlet import Llumlet
 from llumnix.backends.backend_interface import BackendType
 from llumnix.logger import init_logger
 from llumnix.utils import random_uuid
 from llumnix.arg_utils import EngineManagerArgs
+from llumnix.rpc.utils import get_open_zmq_ipc_path
+from llumnix.server_info import ServerInfo
+from llumnix.rpc.queue_server import QueueServer
 
 
 logger = init_logger(__name__)
@@ -171,18 +171,17 @@ def init_llumlets(engine_manager_args: EngineManagerArgs,
         llumlets.append(llumlet)
     return instance_ids, llumlets
 
-def init_request_output_queue() -> RayQueue:
-    # request_output_queue should be placed in the same node as the api server.
-    request_output_queue = RayQueue(actor_options={
-        "scheduling_strategy": NodeAffinitySchedulingStrategy(
-            node_id=ray.get_runtime_context().get_node_id(),
-            soft=False,)
-    })
+def init_request_output_queue(server_info: ServerInfo) -> QueueServer:
+    rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port)
+    request_output_queue = QueueServer(rpc_path)
     return request_output_queue
 
 def init_llumnix_components(engine_manager_args: EngineManagerArgs,
                             engine_args,
-                            node_id: str) -> Tuple[LLMEngineManager, List[Llumlet], RayQueue]:
+                            node_id: str,
+                            server_info: ServerInfo) -> Tuple[LLMEngineManager, List[Llumlet], QueueServer]:
+    request_output_queue = init_request_output_queue(server_info)
+
     engine_manager = init_manager(engine_manager_args)
     if engine_manager_args.disable_init_instance_by_manager:
         instance_ids, llumlets = init_llumlets(engine_manager_args, engine_args, node_id)
@@ -212,6 +211,4 @@ def init_llumnix_components(engine_manager_args: EngineManagerArgs,
     logger.info("Init Llumnix components done, {} instances are ready, instance_ids: {}."
                 .format(len(available_instance_ids), available_instance_ids))
 
-    request_output_queue = init_request_output_queue()
-
     return engine_manager, available_instance_ids, available_llumlets, request_output_queue
```
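
init_request_output_queue now derives the queue endpoint from the ServerInfo built by the API server, via get_open_zmq_ipc_path from llumnix/rpc/utils.py, which is outside this diff. Since the helper receives an IP and a port, its body is presumably just endpoint formatting; the sketch below is a guess for illustration only:

```python
# Hypothetical sketch of llumnix.rpc.utils.get_open_zmq_ipc_path (the real
# helper is not shown in this PR); the tcp:// scheme is an assumption based
# on the ip/port arguments passed in from ServerInfo.
def get_open_zmq_ipc_path(ip: str, port: int) -> str:
    return "tcp://{}:{}".format(ip, port)
```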
39 changes: 22 additions & 17 deletions llumnix/entrypoints/vllm/api_server.py
```diff
@@ -28,11 +28,13 @@
 
 from llumnix.arg_utils import EngineManagerArgs
 from llumnix.server_info import ServerInfo
-from llumnix.entrypoints.llumnix_utils import (launch_ray_cluster, connect_to_ray_cluster,
-                                               is_gpu_available, init_llumnix_components)
+from llumnix.entrypoints.llumnix_utils import (get_ip_address,
+                                               launch_ray_cluster, connect_to_ray_cluster,
+                                               is_gpu_available, init_llumnix_components)
 from llumnix.logger import init_logger
 from llumnix.utils import random_uuid
 from llumnix.backends.vllm.utils import check_engine_args
+from llumnix.rpc.queue_server import QueueServer
 from llumnix.config import get_llumnix_config, LlumnixConfig
 
 logger = init_logger("llumnix.api_server")
@@ -41,8 +43,8 @@
 instances = {}
 instance_num_requests: Dict[str, int] = {}
 # request_output_queue could be None if initialized in lifespan.
-request_output_queue = None
-server_id = None
+request_output_queue: QueueServer = None
+server_info = None
 TIMEOUT_KEEP_ALIVE = 5 # seconds.
 request_streams: Dict[str, AsyncStream] = {}
 log_requests = None
@@ -53,23 +55,23 @@
 
 async def _background_process_outputs():
     while True:
-        qsize = await request_output_queue.actor.qsize.remote()
-        request_outputs = await request_output_queue.actor.get_nowait_batch.remote(qsize)
-        for request_output in request_outputs:
-            request_id = request_output.request_id
-            # Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished.
-            if request_id not in request_streams:
-                continue
-            request_streams[request_id].put(request_output)
-            if request_output.finished:
-                request_streams[request_id].finish()
-                del request_streams[request_id]
+        request_output = await request_output_queue.get()
+        request_id = request_output.request_id
+        # Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished.
+        if request_id not in request_streams:
+            continue
+        request_streams[request_id].put(request_output)
+        if request_output.finished:
+            request_streams[request_id].finish()
+            del request_streams[request_id]
 
 # pylint: disable=unused-argument
 @asynccontextmanager
 async def lifespan(fastapi_app: FastAPI):
+    asyncio.create_task(request_output_queue.run_server_loop())
     asyncio.create_task(_background_process_outputs())
     yield
+    request_output_queue.cleanup()
 
 app = FastAPI(lifespan=lifespan)
 
@@ -79,7 +81,6 @@ async def manager_generate(prompt, sampling_params, request_id) -> AsyncStream:
     results_generator = AsyncStream(request_id)
     request_streams[request_id] = results_generator
     # This request's outputs will be put to the request_output_queue of this api server no matter which instance it's running in.
-    server_info = ServerInfo(server_id, request_output_queue)
     # If manager is unavailable, request will be directly added to the llumlet held by api server.
     global manager_available
     try:
@@ -238,7 +239,9 @@ async def is_ready():
                         action='store_true',
                         default=None,
                         help='if launch ray cluster in api server')
+    parser.add_argument("--request-output-queue-port", type=int, default=1234)
     parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
+
     parser = EngineManagerArgs.add_cli_args(parser)
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
@@ -262,9 +265,11 @@ async def is_ready():
     if is_gpu_available():
         # Launch the Llumnix components on current node.
         server_id = random_uuid()
+        ip = get_ip_address()
+        server_info = ServerInfo(server_id, ip, args.request_output_queue_port)
         node_id = ray.get_runtime_context().get_node_id()
         engine_manager, instance_ids, llumlets, request_output_queue = \
-            init_llumnix_components(engine_manager_args, engine_args, node_id)
+            init_llumnix_components(engine_manager_args, engine_args, node_id, server_info)
 
         for idx, ins_id in enumerate(instance_ids):
             instances[ins_id] = llumlets[idx]
```
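
On the API server side, the old polling loop, which issued two Ray RPCs per iteration (qsize, then get_nowait_batch), becomes a single blocking await on the local QueueServer, and lifespan now starts two background tasks: the ZeroMQ receive loop and the output consumer. A self-contained sketch of that wiring, with a plain asyncio.Queue standing in for the QueueServer (names are illustrative):

```python
# Runnable sketch of the lifespan wiring above (serve with uvicorn).
# fake_server_loop plays the role of QueueServer.run_server_loop and
# process_outputs the role of _background_process_outputs.
import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI

output_queue: asyncio.Queue = asyncio.Queue()


async def fake_server_loop():
    # Stand-in for the zmq receive loop: feed the queue periodically.
    i = 0
    while True:
        await asyncio.sleep(1.0)
        await output_queue.put("request_output_{}".format(i))
        i += 1


async def process_outputs():
    while True:
        # One blocking get replaces the old qsize()/get_nowait_batch() polling.
        item = await output_queue.get()
        print("dispatched", item)


@asynccontextmanager
async def lifespan(app: FastAPI):
    tasks = [asyncio.create_task(fake_server_loop()),
             asyncio.create_task(process_outputs())]
    yield
    for task in tasks:
        task.cancel()


app = FastAPI(lifespan=lifespan)
```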