[Core] Optimize request output tokens putting back implementation to reduce overhead #45

Merged · 15 commits · Oct 11, 2024
2 changes: 1 addition & 1 deletion llumnix/backends/backend_interface.py
@@ -63,7 +63,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
raise NotImplementedError

@abstractmethod
def _start_engine_loop(self) -> None:
def _start_engine_step_loop(self) -> None:
"""Start step loop of backend engine.
"""
raise NotImplementedError
92 changes: 57 additions & 35 deletions llumnix/backends/vllm/llm_engine.py
@@ -17,8 +17,10 @@
from collections import defaultdict
import threading
import asyncio
import queue
import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

from vllm.engine.llm_engine import LLMEngine
from vllm.core.scheduler import ScheduledSequenceGroup
@@ -42,30 +44,21 @@
logger = init_logger(__name__)


class AsyncPutQueueThread(threading.Thread):
class AsyncPutQueueActor:
def __init__(self, instance_id, output_queue_type: QueueType):
super().__init__()
self.instance_id = instance_id

self.request_output_queue_client: QueueClientBase \
= get_output_queue_client(output_queue_type)
self.request_output_queue_client: QueueClientBase = get_output_queue_client(output_queue_type)
self.engine_actor_handle = None
self.loop = asyncio.new_event_loop()
self.daemon = True

def run(self):
asyncio.set_event_loop(self.loop)
self.loop.run_forever()

async def _put_nowait_batch_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
server_info_dict: Dict[str, ServerInfo]) -> None:
async def put_nowait_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
server_info_dict: Dict[str, ServerInfo]) -> None:
if self.engine_actor_handle is None:
self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix")
tasks = []
for server_id, req_outputs in server_request_outputs.items():
server_info = server_info_dict[server_id]
tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait_batch(req_outputs, server_info)))
tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info)))
rets = await asyncio.gather(*tasks, return_exceptions=True)
for idx, ret in enumerate(rets):
if isinstance(ret, TimeoutError):
@@ -78,22 +71,38 @@ async def _put_nowait_batch_to_servers(self,
request_ids = [req_output.request_id for req_output in req_outputs]
self.engine_actor_handle.abort_request.remote(request_ids)

def put_nowait_batch_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
server_info_dict: Dict[str, ServerInfo]) -> None:
asyncio.run_coroutine_threadsafe(self._put_nowait_batch_to_servers(server_request_outputs, server_info_dict),
self.loop)


class LLMEngineLlumnix(LLMEngine):
def __init__(self, instance_id: str, output_queue_type: QueueType, *arg, **kwargs) -> None:
def __init__(self,
instance_id: str,
output_queue_type: QueueType,
placement_group: Optional[PlacementGroup],
node_id: Optional[str],
*arg, **kwargs) -> None:
super().__init__(*arg, **kwargs)
self.instance_id = instance_id
self.step_counter = Counter()
self.instance_info = None
# TODO(s5u13b): Reduce the overhead.
self.async_put_queue_thread = AsyncPutQueueThread(instance_id, output_queue_type)
self.async_put_queue_thread.start()
# Place the async put queue actor together with the instance.
if placement_group:
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
)
else:
scheduling_strategy = NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
)
self.put_queue_args_queue = queue.Queue()
self.put_queue_loop_thread = threading.Thread(
target=self._start_put_queue_loop, args=(), daemon=True, name="put_queue_loop"
)
self.async_put_queue_actor = ray.remote(
num_cpus=1,
scheduling_strategy=scheduling_strategy
)(AsyncPutQueueActor).remote(instance_id, output_queue_type)
self.put_queue_loop_thread.start()

# pylint: disable=W0221
@classmethod
@@ -105,7 +114,7 @@ def from_engine_args(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
instance_id: str = None,
placement_group: Optional[PlacementGroup] = None,
node_id: str = None,
node_id: Optional[str] = None,
latency_mem: Optional[LatencyMemData] = None
) -> "LLMEngineLlumnix":
"""Creates an LLM engine from the engine arguments."""
@@ -130,6 +139,8 @@ def from_engine_args(
engine = cls(
instance_id=instance_id,
output_queue_type=output_queue_type,
placement_group=placement_group,
node_id=node_id,
**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
@@ -186,10 +197,12 @@ def step(self) -> None:
tot_blocks.extend(blocks)
tot_blocks = set(tot_blocks)
instance_info.num_blocks_last_running_request = len(tot_blocks)

if request_outputs:
self._put_request_outputs_to_server(request_outputs, server_infos)
self.put_queue_args_queue.put((request_outputs, server_infos))
self.instance_info = instance_info
num_request_outputs = len(request_outputs)

return num_request_outputs

def update_instance_info(self, instance_info: InstanceInfo) -> None:
# These fields are updated after step.
@@ -208,7 +221,13 @@ def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs)
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _put_request_outputs_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
def _start_put_queue_loop(self):
while True:
args = self.put_queue_args_queue.get()
request_outputs, server_infos = args
self._put_request_outputs_to_server(request_outputs, server_infos)

def _put_request_outputs_to_server(self, request_outputs: List[RequestOutput], server_infos: List[ServerInfo]) -> None:
server_request_outputs = defaultdict(list)
server_info_dict = {}
# Reorganize data in order to put request outputs to each queue in a single batch.
@@ -217,7 +236,8 @@ def _put_request_outputs_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
server_request_outputs[server_id].append(request_output)
if server_id not in server_info_dict:
server_info_dict[server_id] = server_info
self.async_put_queue_thread.put_nowait_batch_to_servers(server_request_outputs, server_info_dict)
# TODO(s5u13b): Reduce the cross-actor overhead.
self.async_put_queue_actor.put_nowait_to_servers.remote(server_request_outputs, server_info_dict)

class BackendVLLM(BackendInterface):
def __init__(
@@ -251,12 +271,12 @@ def __init__(
logger.info("engine ({}) current state {}".format(self.instance_id, self.state))

self._stop_event = threading.Event()
self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
self.engine_step_loop_thread = threading.Thread(
target=self._start_engine_step_loop, args=(), daemon=True, name="engine_step_loop"
)
self._thread.start()
self.engine_step_loop_thread.start()

def _start_engine_loop(self) -> None:
def _start_engine_step_loop(self) -> None:
self._stop_event.clear()

with self.state_lock:
@@ -266,7 +286,9 @@ def _start_engine_loop(self) -> None:

while not self._stop_event.is_set():
try:
self.engine.step()
num_request_outputs = self.engine.step()
if num_request_outputs == 0:
time.sleep(0.01)
# pylint: disable=broad-except
except Exception as e:
logger.error("Error in engine loop: {}".format(e))
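Taken together, the llm_engine.py changes replace the in-process `AsyncPutQueueThread` with a two-stage pipeline: `step()` only enqueues its outputs into a local `queue.Queue`, a daemon thread drains that queue, and a dedicated Ray actor performs the actual network puts. A condensed sketch of the pattern (the `PutBackPipeline` wrapper class is hypothetical; the attribute and method names are taken from the diff above):

```python
import queue
import threading
from collections import defaultdict

class PutBackPipeline:  # hypothetical wrapper, for illustration only
    def __init__(self, async_put_queue_actor):
        # Handle to the colocated AsyncPutQueueActor created in __init__ above.
        self.async_put_queue_actor = async_put_queue_actor
        self.put_queue_args_queue = queue.Queue()
        threading.Thread(target=self._put_queue_loop, daemon=True,
                         name="put_queue_loop").start()

    def put(self, request_outputs, server_infos):
        # Called from step(): enqueue locally and return immediately.
        self.put_queue_args_queue.put((request_outputs, server_infos))

    def _put_queue_loop(self):
        while True:
            request_outputs, server_infos = self.put_queue_args_queue.get()
            grouped, info_dict = self._group(request_outputs, server_infos)
            # Fire-and-forget cross-actor call: .remote() returns an ObjectRef
            # immediately, so this loop never blocks on the network put.
            self.async_put_queue_actor.put_nowait_to_servers.remote(grouped, info_dict)

    @staticmethod
    def _group(request_outputs, server_infos):
        # One batched put per server instead of one RPC per request output.
        server_request_outputs = defaultdict(list)
        server_info_dict = {}
        for request_output, server_info in zip(request_outputs, server_infos):
            server_request_outputs[server_info.server_id].append(request_output)
            server_info_dict.setdefault(server_info.server_id, server_info)
        return server_request_outputs, server_info_dict
```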
4 changes: 2 additions & 2 deletions llumnix/config/default.py
@@ -27,7 +27,7 @@
# Port number for the server
_C.SERVER.PORT = 8000
# Queue type for request output queue
_C.SERVER.QUEUE_TYPE = "rayqueue"
_C.SERVER.QUEUE_TYPE = "zmq"
# Port number for the request output queue
_C.SERVER.REQUEST_OUTPUT_QUEUE_PORT = 1234
# Path to SSL key file for secure connections
@@ -42,7 +42,7 @@
# -----------------------------------------------------------------------------
_C.RAY = LC()
# Port number for the Ray cluster
_C.RAY.RAY_CLUSTER_PORT = 30050
_C.RAY.RAY_CLUSTER_PORT = 6379
# If True, launch Ray cluster in API server
_C.RAY.LAUNCH_RAY_CLUSTER = False

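The new defaults switch the output transport from `rayqueue` to `zmq` and align the Ray cluster port with Ray's standard 6379. A minimal override sketch, assuming `LC()` follows the usual yacs-style `CfgNode` API (`clone()` plus attribute assignment):

```python
from llumnix.config.default import _C

cfg = _C.clone()
cfg.SERVER.QUEUE_TYPE = "rayqueue"  # fall back to the previous transport
cfg.RAY.RAY_CLUSTER_PORT = 30050    # restore the old non-default port
```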
23 changes: 12 additions & 11 deletions llumnix/entrypoints/vllm/api_server.py
@@ -56,15 +56,16 @@

async def _background_process_outputs():
while True:
request_output = await request_output_queue.get()
request_id = request_output.request_id
# The request could be dispatched twice when the manager is dead; the first dispatch frees its request_streams entry when finished.
if request_id not in request_streams:
continue
request_streams[request_id].put(request_output)
if request_output.finished:
request_streams[request_id].finish()
del request_streams[request_id]
request_outputs = await request_output_queue.get()
for request_output in request_outputs:
request_id = request_output.request_id
# The request could be dispatched twice when the manager is dead; the first dispatch frees its request_streams entry when finished.
if request_id not in request_streams:
continue
request_streams[request_id].put(request_output)
if request_output.finished:
request_streams[request_id].finish()
del request_streams[request_id]

# pylint: disable=unused-argument
@asynccontextmanager
@@ -180,11 +181,11 @@ async def generate_benchmark(request: Request) -> Response:
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

start = time.time()

results_generator = await manager_generate(prompt, sampling_params, request_id)

per_token_latency = []
start = time.time()

# Non-streaming case
final_output = None
async for request_output in results_generator:
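Two things change in api_server.py: the background task now receives whole batches from the output queue, and `generate_benchmark` starts its clock only after `manager_generate()` returns, so dispatch overhead no longer inflates per-token latency. A sketch of the adjusted measurement (the `consume()` helper is hypothetical):

```python
import time

async def consume(results_generator):
    per_token_latency = []
    start = time.time()  # started after dispatch, per the diff above
    async for request_output in results_generator:
        now = time.time()
        per_token_latency.append(now - start)
        start = now
    return per_token_latency
```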
20 changes: 13 additions & 7 deletions llumnix/llumlet/llumlet.py
@@ -15,8 +15,7 @@
from typing import List, Union, Iterable
import time
import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
@@ -86,7 +85,9 @@ def from_args(cls,
lifetime=lifetime)(cls).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=0,))
placement_group_bundle_index=0,
)
)
else:
kwargs["node_id"] = node_id
engine_class = ray.remote(num_cpus=1,
@@ -96,16 +97,21 @@
lifetime=lifetime)(cls).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,))
soft=False,
)
)
else: # backend_type == backend_type.SIM_VLLM:
kwargs["node_id"] = node_id
engine_class = ray.remote(num_cpus=1,
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,))
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
)
)
llumlet = engine_class.remote(instance_id, output_queue_type, backend_type, migration_config, *args, **kwargs)
return llumlet

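The llumlet changes are mostly formatting, but they highlight the two placement paths: capture into a placement group when one is provided, otherwise pin to a node with a hard affinity constraint. A minimal sketch of that selection (`make_strategy()` is hypothetical; the Ray strategies are the ones imported above):

```python
from ray.util.scheduling_strategies import (NodeAffinitySchedulingStrategy,
                                            PlacementGroupSchedulingStrategy)

def make_strategy(placement_group=None, node_id=None):
    if placement_group is not None:
        return PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_bundle_index=0,  # first bundle, as in from_args above
        )
    # Hard constraint: fail rather than fall back to another node.
    return NodeAffinitySchedulingStrategy(node_id=node_id, soft=False)
```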
2 changes: 1 addition & 1 deletion llumnix/queue/queue_client_base.py
@@ -18,5 +18,5 @@

class QueueClientBase(ABC):
@abstractmethod
async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
async def put_nowait(self, items: Iterable, server_info: ServerInfo):
raise NotImplementedError
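Implementations of the renamed interface keep the batch semantics: `items` is still an iterable of request outputs delivered in one call. A minimal in-process stand-in (hypothetical, for illustration only):

```python
from typing import Iterable

from llumnix.queue.queue_client_base import QueueClientBase
from llumnix.server_info import ServerInfo

class LocalQueueClient(QueueClientBase):
    """Buffers outputs in memory instead of sending them anywhere."""

    def __init__(self):
        self.buffer = []

    async def put_nowait(self, items: Iterable, server_info: ServerInfo):
        self.buffer.extend(items)  # one call still carries a whole batch
```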
4 changes: 2 additions & 2 deletions llumnix/queue/ray_queue_client.py
@@ -18,6 +18,6 @@
from llumnix.queue.queue_client_base import QueueClientBase

class RayQueueClient(QueueClientBase):
async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
async def put_nowait(self, items: Iterable, server_info: ServerInfo):
output_queue = server_info.request_output_queue
return await output_queue.actor.put_nowait_batch.remote(items)
return await output_queue.actor.put_nowait.remote(items)
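Note the semantic shift on the Ray transport: `put_nowait` enqueues the whole batch as a single queue item, so the server-side `get()` now yields a list of request outputs rather than one output at a time, matching the new loop in `_background_process_outputs` above. A minimal illustration (requires a running Ray runtime):

```python
from ray.util.queue import Queue

q = Queue()
q.put_nowait(["out_1", "out_2"])  # one queue item that is a batch
batch = q.get_nowait()            # -> ["out_1", "out_2"]
```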
8 changes: 4 additions & 4 deletions llumnix/queue/zmq_client.py
@@ -22,8 +22,8 @@
from llumnix.server_info import ServerInfo

from llumnix.queue.zmq_utils import (RPC_GET_DATA_TIMEOUT_MS, RPC_SOCKET_LIMIT_CUTOFF, RPC_ZMQ_HWM, RPC_SUCCESS_STR,
RPCClientClosedError, RPC_REQUEST_TYPE, RPCUtilityRequest, RPCPutNoWaitBatchQueueRequest,
get_open_zmq_ipc_path)
RPCClientClosedError, RPC_REQUEST_TYPE, RPCUtilityRequest, RPCPutNoWaitQueueRequest,
get_open_zmq_ipc_path)

logger = init_logger(__name__)

@@ -104,9 +104,9 @@ async def wait_for_server_rpc(self,
rpc_path=rpc_path,
error_message="Unable to start RPC Server")

async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
async def put_nowait(self, items: Iterable, server_info: ServerInfo):
rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port)
await self._send_one_way_rpc_request(
request=RPCPutNoWaitBatchQueueRequest(items=items),
request=RPCPutNoWaitQueueRequest(items=items),
rpc_path=rpc_path,
error_message="Unable to put items into queue.")
10 changes: 5 additions & 5 deletions llumnix/queue/zmq_server.py
@@ -20,7 +20,7 @@
import cloudpickle

from llumnix.queue.zmq_utils import (RPC_ZMQ_HWM, RPC_SUCCESS_STR, RPC_SOCKET_LIMIT_CUTOFF,
RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest)
RPCPutNoWaitQueueRequest, RPCUtilityRequest)
from llumnix.logger import init_logger

logger = init_logger(__name__)
@@ -110,18 +110,18 @@ def _make_handler_coro(self, identity,
request = cloudpickle.loads(message)
if request == RPCUtilityRequest.IS_SERVER_READY:
return self._is_server_ready(identity)
if isinstance(request, RPCPutNoWaitBatchQueueRequest):
return self._put_nowait_batch(identity, request)
if isinstance(request, RPCPutNoWaitQueueRequest):
return self._put_nowait(identity, request)

raise ValueError(f"Unknown RPCRequest type: {request}")

async def _is_server_ready(self, identity):
await self.socket.send_multipart(
[identity, cloudpickle.dumps(RPC_SUCCESS_STR)])

async def _put_nowait_batch(self, identity, put_nowait_batch_queue_request: RPCPutNoWaitBatchQueueRequest):
async def _put_nowait(self, identity, put_nowait_queue_request: RPCPutNoWaitQueueRequest):
try:
self.put_nowait_batch(put_nowait_batch_queue_request.items)
self.put_nowait(put_nowait_queue_request.items)
await self.socket.send_multipart(
[identity, cloudpickle.dumps(RPC_SUCCESS_STR)])
# pylint: disable=W0703
4 changes: 2 additions & 2 deletions llumnix/queue/zmq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
RPC_SUCCESS_STR = "SUCCESS"

@dataclass
class RPCPutNoWaitBatchQueueRequest:
class RPCPutNoWaitQueueRequest:
items: List[Any] = None

class RPCUtilityRequest(Enum):
IS_SERVER_READY = 1

# pylint: disable=C0103
RPC_REQUEST_TYPE = Union[RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest]
RPC_REQUEST_TYPE = Union[RPCPutNoWaitQueueRequest, RPCUtilityRequest]

class RPCClientClosedError(Exception):
"""Exception class raised when the client is used post-close.
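On the ZMQ path the rename is purely mechanical: the wire format and one-way semantics are unchanged. A condensed sketch of the server-side dispatch (the `server` parameter stands in for the ZMQ server object; the request types come from zmq_utils.py above):

```python
import cloudpickle

from llumnix.queue.zmq_utils import (RPC_SUCCESS_STR, RPCPutNoWaitQueueRequest,
                                     RPCUtilityRequest)

def dispatch(server, message):
    request = cloudpickle.loads(message)
    if request == RPCUtilityRequest.IS_SERVER_READY:
        return RPC_SUCCESS_STR
    if isinstance(request, RPCPutNoWaitQueueRequest):
        server.put_nowait(request.items)  # enqueue the batch, then ack
        return RPC_SUCCESS_STR
    raise ValueError(f"Unknown RPCRequest type: {request}")
```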
6 changes: 2 additions & 4 deletions llumnix/server_info.py
@@ -23,10 +23,8 @@ def __init__(self,
request_output_queue_port: int) -> None:
self.server_id = server_id
self.output_queue_type = output_queue_type

if output_queue_type == QueueType.RAYQUEUE:
assert request_output_queue is not None and hasattr(request_output_queue, "queue")
self.request_output_queue = request_output_queue.queue if hasattr(request_output_queue, "queue") else None

assert request_output_queue is not None
self.request_output_queue = request_output_queue.queue if output_queue_type == QueueType.RAYQUEUE else None
self.request_output_queue_ip = request_output_queue_ip
self.request_output_queue_port = request_output_queue_port