[BugFix] Fix request output loss when putting outputs back to the API server #27

Merged
merged 4 commits on Sep 4, 2024
Makefile: 2 changes (1 addition, 1 deletion)
@@ -25,7 +25,7 @@ lint: check_pylint_installed

.PHONY: test
test: check_pytest_installed
-	@pytest -q -x --ignore=third_party/ --disable-warnings
+	@pytest -x --ignore=third_party/ --disable-warnings

#################### pygloo install for gloo migration backend begin ####################

llumnix/backends/vllm/llm_engine.py: 5 changes (4 additions, 1 deletion)
@@ -111,6 +111,8 @@ def _process_model_outputs(
scheduled_seq_groups = new_scheduled_seq_groups
output[0].outputs = new_output
seq_group_metadata_list = new_seq_group_metadata_list
+for ignored_seq_group in ignored_seq_groups:
+    server_info_list.append(ignored_seq_group.server_info)
request_outputs = super()._process_model_outputs(output, scheduled_seq_groups, ignored_seq_groups, seq_group_metadata_list)
# TODO(ZeldaHuang) Use LlumnixRequestOutput to store llumnix output args.
return request_outputs, server_info_list
@@ -149,7 +151,8 @@ def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs)
super().add_request(request_id, *args, **kwargs)
seq_group = self.scheduler.waiting[-1]
self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, [seq_group.get_seqs()[0]], seq_group.sampling_params,
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
+self.scheduler.scheduler_lock.release()

def _put_request_output_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
server_request_outputs = defaultdict(list)
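
Note on the change above: an ignored sequence group (for example, a request whose prompt is too long to schedule) still yields a finished RequestOutput from _process_model_outputs, so its ServerInfo must be appended to server_info_list as well; otherwise that output can never be routed back to the API server that submitted the request, which is the output loss this PR fixes. Below is a minimal sketch of the per-server grouping that _put_request_output_to_server appears to perform, judging from the two lines shown above; route_outputs is a hypothetical name, not code from this repository:

from collections import defaultdict

def route_outputs(request_outputs, server_infos):
    # Pair each RequestOutput with the ServerInfo captured at add_request time
    # and group them per server. In this sketch zip() stops at the shorter list,
    # so an output without a matching ServerInfo is silently dropped, mirroring
    # the loss described in the PR title.
    server_request_outputs = defaultdict(list)
    for request_output, server_info in zip(request_outputs, server_infos):
        server_request_outputs[server_info].append(request_output)
    return server_request_outputs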
llumnix/backends/vllm/scheduler.py: 6 changes (4 additions, 2 deletions)
@@ -195,7 +195,7 @@ def _get_instance_info(self) -> InstanceInfo:
instance_info.inference_type = self.running[-1].inference_type
# TODO(ZeldaHuang) adapt chunked-prefill
instance_info.num_batched_tokens = sum([seq_group.request_len for seq_group in self.running])\
if instance_info.inference_type == RequestInferenceType.PREFILL else len(instance_info.running_seq_lens)
instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.is_finished()]
return instance_info

@@ -205,8 +205,10 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
self.update_instance_info_callback(self._get_instance_info())
return seq_group_metadata_list, scheduler_outputs

-@scheduler_lock
def add_seq_group(self, *args, **kwargs):
+    # The scheduler lock is manually released at the end of LLMEngineLlumnix.add_request.
+    # pylint: disable=R1732
+    self.scheduler_lock.acquire()
    return super().add_seq_group(*args, **kwargs)

@scheduler_lock
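
The acquire above pairs with the scheduler_lock.release() added at the end of LLMEngineLlumnix.add_request in llm_engine.py: add_seq_group deliberately returns while still holding the lock, so the engine can replace the freshly appended waiting-queue entry with a SequenceGroupLlumnix before any other thread touches the queue. That is also why pylint's R1732 (consider-using-with) is suppressed: the lock outlives the function that acquires it, so it cannot be scoped to a with block. A toy sketch of this acquire-in-callee, release-in-caller pattern follows; TinyScheduler and TinyEngine are hypothetical stand-ins, not classes from this repository:

import threading

class TinyScheduler:
    def __init__(self):
        self.scheduler_lock = threading.Lock()
        self.waiting = []

    def add_seq_group(self, seq_group):
        # Acquired here but deliberately not released: the caller still needs
        # exclusive access to rewrite the entry it just appended.
        self.scheduler_lock.acquire()
        self.waiting.append(seq_group)

class TinyEngine:
    def __init__(self):
        self.scheduler = TinyScheduler()

    def add_request(self, request_id, server_info, prompt):
        self.scheduler.add_seq_group({"request_id": request_id, "prompt": prompt})
        # Swap in an entry that also carries server_info while the lock is held.
        self.scheduler.waiting[-1] = {"request_id": request_id,
                                      "server_info": server_info,
                                      "prompt": prompt}
        self.scheduler.scheduler_lock.release()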
llumnix/backends/vllm/sequence.py: 1 change (1 addition, 0 deletions)
@@ -15,6 +15,7 @@

from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType


class SequenceGroupLlumnix(SequenceGroup, LlumnixRequest):
def __init__(self, request_id, server_info, *args, **kwargs) -> None:
SequenceGroup.__init__(self, request_id, *args, **kwargs)
tests/backends/vllm/test_llm_engine.py: 3 changes (2 additions, 1 deletion)
@@ -97,7 +97,8 @@ def test_llm_engine_add_requset():
engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
llm_engine = LLMEngineLlumnix.from_engine_args(engine_args, instance_id="0", migration_config=None, latency_mem=MagicMock(sepc=LatencyMemData))
sampling_params = SamplingParams(top_k=1, temperature=0, ignore_eos=True, max_tokens=100)
llm_engine.add_request("0", None,"prompt", sampling_params)
llm_engine.scheduler.scheduler_lock = MagicMock()
llm_engine.add_request("0", None, "prompt", sampling_params)
assert len(llm_engine.scheduler.waiting) == 1
assert llm_engine.scheduler.waiting[-1].request_id == "0"
assert isinstance(llm_engine.scheduler.waiting[-1], LlumnixRequest)
tests/backends/vllm/utils.py: 12 changes (10 additions, 2 deletions)
@@ -23,18 +23,26 @@
from llumnix.backends.vllm.scheduler import SchedulerLlumnix
from llumnix.backends.vllm.sequence import SequenceGroupLlumnix


+class SchedulerLlumnixTest(SchedulerLlumnix):
+    def add_seq_group(self, *args, **kwargs):
+        ret = super().add_seq_group(*args, **kwargs)
+        self.scheduler_lock.release()
+        return ret


def initialize_scheduler(*,
max_num_seqs=1000,
max_token_budget=1000,
max_model_len=1000,
-lora_config=None) -> SchedulerLlumnix:
+lora_config=None) -> SchedulerLlumnixTest:
block_size = 4
scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs,
max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
-scheduler = SchedulerLlumnix(scheduler_config, cache_config, lora_config)
+scheduler = SchedulerLlumnixTest(scheduler_config, cache_config, lora_config)
scheduler.update_instance_info_callback = MagicMock()
return scheduler
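
The release inside SchedulerLlumnixTest exists because these unit tests drive the scheduler directly: LLMEngineLlumnix.add_request, which normally performs the matching release, is never called here, so without it a second add_seq_group call would block on the still-held scheduler_lock (assuming a non-reentrant lock). A hypothetical usage sketch follows; seq_group_a and seq_group_b stand for SequenceGroupLlumnix objects a test would build and are not names from this repository:

scheduler = initialize_scheduler()
# Each call acquires scheduler_lock in SchedulerLlumnix.add_seq_group, and the
# test subclass releases it right away, so back-to-back calls do not block.
scheduler.add_seq_group(seq_group_a)
scheduler.add_seq_group(seq_group_b)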
