Commit 6c34033

[CI] Add simulator e2e test and tensor parallelism unit test and e2e test (#112)

s5u13b authored Feb 25, 2025
1 parent 5743023 commit 6c34033
Showing 6 changed files with 114 additions and 65 deletions.
4 changes: 4 additions & 0 deletions llumnix/backends/profiling.py
@@ -151,6 +151,8 @@ def _extract_data(self, row):
"""Extract the profiling results from a row of the profiling CSV file."""
# assert pp==1
profiling_data = row["profiling_data"].strip('"()"').split(",")
if profiling_data[0].strip() == '' or profiling_data[1].strip() == 'None' or profiling_data[2] == '0' or profiling_data[3] == '0':
return None, None, None, None
inference_type = RequestInferenceType.PREFILL if profiling_data[0] == "'prefill'" else RequestInferenceType.DECODE
batch_size = _pad_to_alignment(int(profiling_data[1]), 8)
tot_seq_len = _pad_to_alignment(int(profiling_data[2]), 8)
@@ -166,6 +168,8 @@ def update_from_instance_log(self, file_name: str, model: str, parallel_config:
self.results[model] = ProfilingResult(model, {})
for _, row in df.iterrows():
stage_latencies, inference_type, batch_size, tot_seq_len = self._extract_data(row)
if not stage_latencies:
continue
self.results[model].add_latency_result(parallel_config, inference_type, batch_size, tot_seq_len, stage_latencies)

def materialize(self):
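Aside (not part of the commit): a minimal, self-contained sketch of the malformed-row guard added above. The extract helper and sample strings are hypothetical, assuming the CSV stores each profiling tuple as a quoted "(type, batch, seq_len, latency)" string as the parsing above suggests.

# Hypothetical sketch of the guard in _extract_data; not taken from the repo.
def extract(profiling_data_field: str):
    # Mirror the parsing above: strip the quoted-tuple wrapper, split on commas.
    data = profiling_data_field.strip('"()"').split(",")
    # New guard: skip rows with empty/None/zero fields instead of crashing on int().
    if data[0].strip() == '' or data[1].strip() == 'None' or data[2] == '0' or data[3] == '0':
        return None, None, None, None
    return data[0].strip(), int(data[1]), int(data[2]), float(data[3])

print(extract('"(\'prefill\',8,512,0.42)"'))  # ("'prefill'", 8, 512, 0.42)
print(extract('"(,None,0,0)"'))               # (None, None, None, None)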
4 changes: 4 additions & 0 deletions llumnix/backends/vllm/sim_executor.py
@@ -88,3 +88,7 @@ async def execute_model_async(
output = CompletionSequenceGroupOutput(samples, None)
sampler_outputs.append(output)
return [SamplerOutput(outputs=sampler_outputs)]

async def send_blocks(self, blocks_len) -> None:
migration_latency = (self.cache_block_size * blocks_len) / self.migration_bandwidth
await asyncio.sleep(migration_latency)
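Aside (not part of the commit): send_blocks simulates a migration by sleeping for transfer-size divided by bandwidth instead of actually moving KV-cache blocks. A runnable sketch with assumed numbers; the block size and bandwidth constants below are made up, not taken from llumnix.

import asyncio

CACHE_BLOCK_SIZE = 16 * 4096        # assumed bytes per KV-cache block
MIGRATION_BANDWIDTH = 10 * 10**9    # assumed bytes/s on the migration link

async def send_blocks(blocks_len: int) -> None:
    # Same latency model as the diff above: total bytes divided by bandwidth.
    migration_latency = (CACHE_BLOCK_SIZE * blocks_len) / MIGRATION_BANDWIDTH
    await asyncio.sleep(migration_latency)

asyncio.run(send_blocks(256))  # sleeps ~1.7 ms for 256 blocks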
49 changes: 28 additions & 21 deletions tests/e2e_test/test_bench.py
@@ -67,13 +67,20 @@ def get_markdown_data(key: str, head_name: str):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for simple benchmark")
@pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B'])
@pytest.mark.parametrize("launch_mode", ['global', 'local'])
@pytest.mark.parametrize("enable_pd_disagg", [False, True])
@pytest.mark.parametrize("enable_simulator", [False, True])
async def test_simple_benchmark(ray_env, shutdown_llumnix_service, model, launch_mode, enable_pd_disagg, enable_simulator):
if enable_simulator and enable_pd_disagg:
pytest.skip("When enabling simulator, prefill-decode disaggregation is not tested.")

if launch_mode == 'local':
num_prompts = 500 if not enable_pd_disagg else 50
else:
num_prompts = 50 if not enable_pd_disagg else 50

if enable_simulator:
num_prompts = 50

ip = get_ip_address()
base_port = 37037
ip_ports = []
@@ -85,35 +92,36 @@ async def test_simple_benchmark(ray_env, shutdown_llumnix_service, model, launch
ip_port = f"{ip}:{port}"
ip_ports.append(ip_port)
launch_command = generate_launch_command(result_filename=str(base_port+i)+".out",
                                                     launch_ray_cluster=False,
                                                     ip=ip,
                                                     port=port,
                                                     model=model,
                                                     enable_pd_disagg=enable_pd_disagg,
                                                     instance_type="prefill")
subprocess.run(launch_command, shell=True, check=True)
for i in range(device_count//2):
port = base_port+i+device_count//2
ip_port = f"{ip}:{port}"
ip_ports.append(ip_port)
launch_command = generate_launch_command(result_filename=str(base_port+i)+".out",
                                                     launch_ray_cluster=False,
                                                     ip=ip,
                                                     port=port,
                                                     model=model,
                                                     enable_pd_disagg=enable_pd_disagg,
                                                     instance_type="decode")
subprocess.run(launch_command, shell=True, check=True)
else:
for i in range(device_count):
port = base_port+i
ip_port = f"{ip}:{port}"
ip_ports.append(ip_port)
launch_command = generate_launch_command(result_filename=str(base_port+i)+".out",
                                                     launch_ray_cluster=False,
                                                     ip=ip,
                                                     port=port,
                                                     model=model,
                                                     enable_simulator=enable_simulator)
subprocess.run(launch_command, shell=True, check=True)
else: # global
device_count = torch.cuda.device_count()
@@ -125,9 +133,8 @@ async def test_simple_benchmark(ray_env, shutdown_llumnix_service, model, launch
ip=ip,
port=base_port,
model=model,
                                                 enable_pd_disagg=enable_pd_disagg,
                                                 enable_simulator=enable_simulator)
subprocess.run(serve_command, shell=True, check=True)
wait_for_llumnix_service_ready(ip_ports)

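Aside (not part of the commit): per model, the parametrize/skip combination above expands roughly as sketched below; the simulator is only exercised without prefill-decode disaggregation.

import itertools

# Mirror of the test matrix: 2 launch modes x 3 surviving flag pairs = 6 cases per model.
for launch_mode, enable_pd_disagg, enable_simulator in itertools.product(
        ['global', 'local'], [False, True], [False, True]):
    if enable_simulator and enable_pd_disagg:
        continue  # pytest.skip in the real test
    print(launch_mode, enable_pd_disagg, enable_simulator)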
25 changes: 15 additions & 10 deletions tests/e2e_test/test_migration.py
@@ -94,18 +94,22 @@ def get_instance_num_blocks():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench")
@pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B'])
@pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo', 'nccl'])
@pytest.mark.parametrize("migration_request_status", ['running', 'waiting'])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
async def test_migration_benchmark(ray_env, shutdown_llumnix_service, model, migration_backend, migration_request_status, tensor_parallel_size):
if migration_request_status == 'waiting' and migration_backend != 'gloo':
pytest.skip("When the migrated request status is waiting, only test the gloo migration backend.")
if tensor_parallel_size == 2 and migration_backend != 'gloo':
pytest.skip("When the tensor parallel size is 2, only test the gloo migration backend.")

request_migration_policy = 'SR' if migration_request_status == 'running' else 'FCW'
ip = get_ip_address()
base_port = 37037
ip_ports = []
instance_output_logs = []
device_count = torch.cuda.device_count()
num_instances = device_count // tensor_parallel_size
for i in range(num_instances):
port = base_port + i
ip_ports.append(f"{ip}:{base_port+i}")
output_log = f"{base_port+i}.out"
@@ -117,7 +121,8 @@ async def test_migration_benchmark(ray_env, shutdown_llumnix_service, model, mig
model=model,
dispatch_policy="flood",
migration_backend=migration_backend,
request_migration_policy=request_migration_policy,
tensor_parallel_size=tensor_parallel_size)
subprocess.run(launch_command, shell=True, check=True)

wait_for_llumnix_service_ready(ip_ports)
@@ -130,7 +135,7 @@ def run_bench_command(command):
return process

tasks = []
for i in range(num_instances // 2):
bench_command = generate_bench_command(
ip_ports=f"{ip}:{base_port + i}",
model=model,
@@ -162,7 +167,7 @@ def run_bench_command(command):

assert instance_num_blocks_list_before_bench == instance_num_blocks_list_after_bench

if migration_request_status == 'running' and tensor_parallel_size == 1:
average_speed = parse_instance_log_file(instance_output_logs)
sorted_keys = sorted(average_speed.keys(), key=lambda x: float(x.split()[0])*1024 if 'GB' in x else float(x.split()[0]))
data = [
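Aside (not part of the commit): two details above are worth illustrating. The instance count now scales down with tensor parallelism, and the migration-speed buckets, keyed by mixed MB/GB strings, are sorted by normalizing GB to MB. The sample dictionary below is hypothetical.

# Hypothetical data; key shape matches what the sort lambda above expects.
device_count = 4
tensor_parallel_size = 2
num_instances = device_count // tensor_parallel_size  # 2 instances, one per TP group

average_speed = {'512 MB': 3, '2 GB': 1, '0.5 GB': 7, '128 MB': 9}
sorted_keys = sorted(average_speed.keys(),
                     key=lambda x: float(x.split()[0])*1024 if 'GB' in x else float(x.split()[0]))
print(sorted_keys)  # ['128 MB', '512 MB', '0.5 GB', '2 GB']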
16 changes: 12 additions & 4 deletions tests/e2e_test/utils.py
@@ -32,7 +32,9 @@ def generate_launch_command(result_filename: str = "",
request_migration_policy: str = 'SR',
max_num_batched_tokens: int = 16000,
enable_pd_disagg: bool = False,
instance_type: str = "no_constraints",
tensor_parallel_size: int = 1,
enable_simulator: bool = False):
command = (
f"RAY_DEDUP_LOGS=0 HEAD_NODE_IP={HEAD_NODE_IP} HEAD_NODE=1 "
f"nohup python -u -m llumnix.entrypoints.vllm.api_server "
@@ -51,12 +53,14 @@ def generate_launch_command(result_filename: str = "",
f"--request-migration-policy {request_migration_policy} "
f"--migration-backend {migration_backend} "
f"--migration-buffer-blocks 32 "
f"--tensor-parallel-size 1 "
f"--tensor-parallel-size {tensor_parallel_size} "
f"--request-output-queue-port {1234+port} "
f"{'--launch-ray-cluster ' if launch_ray_cluster else ''}"
f"{'--enable-pd-disagg ' if enable_pd_disagg else ''}"
f"--instance-type {instance_type} "
f"--max-num-batched-tokens {max_num_batched_tokens} "
f"{'--simulator-mode ' if enable_simulator else ''}"
f"{'--profiling-result-file-path /mnt/model/simulator/Qwen-7B.pkl' if enable_simulator else ''}"
f"{'> instance_'+result_filename if len(result_filename)> 0 else ''} 2>&1 &"
)
return command
@@ -68,12 +72,13 @@ def generate_serve_command(result_filename: str = "",
migration_backend = "gloo",
model = "facebook/opt-125m",
max_model_len: int = 4096,
log_instance_info: bool = False,
log_request_timestamps: bool = True,
request_migration_policy: str = 'SR',
max_num_batched_tokens: int = 16000,
enable_pd_disagg: bool = False,
pd_ratio: str = "1:1",
enable_simulator: bool = False):
command = (
f"RAY_DEDUP_LOGS=0 "
f"nohup python -u -m llumnix.entrypoints.vllm.serve "
Expand All @@ -97,6 +102,9 @@ def generate_serve_command(result_filename: str = "",
f"--pd-ratio {pd_ratio} "
f"--enable-port-increment "
f"{'--enable-pd-disagg ' if enable_pd_disagg else ''}"
f"{'--simulator-mode ' if enable_simulator else ''}"
f"--max-instances 4 "
f"{'--profiling-result-file-path /mnt/model/simulator/Qwen-7B.pkl' if enable_simulator else ''}"
f"{'> instance_'+result_filename if len(result_filename)> 0 else ''} 2>&1 &"
)
return command
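Aside (not part of the commit): a usage sketch of the updated helpers; all values are illustrative, and the returned strings are meant for subprocess.run(..., shell=True) as in the tests above.

# Hypothetical invocations; paths and ports are examples only.
launch_cmd = generate_launch_command(result_filename="37037.out",
                                     launch_ray_cluster=False,
                                     ip="127.0.0.1",
                                     port=37037,
                                     model="/mnt/model/Qwen-7B",
                                     tensor_parallel_size=2)  # new knob
serve_cmd = generate_serve_command(result_filename="serve.out",
                                   ip="127.0.0.1",
                                   port=37137,
                                   model="/mnt/model/Qwen-7B",
                                   enable_simulator=True)  # adds --simulator-mode and the profiling pickle path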