feat: support config file for vllm runtime (#780)
**Reason for Change**:

Support loading configuration from a file and overriding runtime args.

**Requirements**

- [x] added unit tests and e2e tests (if applicable).

---------

Signed-off-by: jerryzhuang <zhuangqhc@gmail.com>
zhuangqh authored Dec 17, 2024
1 parent 7da6586 commit 42f9ebc
Showing 3 changed files with 169 additions and 67 deletions.
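
For illustration, here is a minimal sketch of how the new flag is exercised, mirroring the test setup in this PR. The config path, model, and override values are placeholders, not shipped defaults:

import subprocess
import yaml

# Hypothetical config: the "vllm" section holds engine/server overrides,
# "max_probe_steps" bounds the GPU-memory probing (0 disables it).
config = {
    "max_probe_steps": 0,
    "vllm": {"max-model-len": 1024, "served-model-name": "mymodel"},
}
with open("/tmp/inference_config.yaml", "w") as f:
    yaml.safe_dump(config, f)

# Values from the config file are appended after the CLI flags, so they win:
# the server ends up serving "mymodel" with max-model-len 1024 on port 5000 + LOCAL_RANK.
server = subprocess.Popen([
    "python3", "inference_api.py",
    "--model", "facebook/opt-125m",
    "--max-model-len", "2048",  # expected to be overridden by the config file
    "--kaito-config-file", "/tmp/inference_config.yaml",
])
# ...query the OpenAI-compatible API, then: server.terminate()
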
19 changes: 19 additions & 0 deletions charts/kaito/workspace/templates/inference-params.yaml
@@ -0,0 +1,19 @@
apiVersion: v1
kind: ConfigMap
metadata:
name: inference-params-template
namespace: {{ .Release.Namespace }}
data:
inference_config.yaml: |
# Maximum number of steps to find the max available seq len fitting in the GPU memory.
max_probe_steps: 6
vllm:
cpu-offload-gb: 0
gpu-memory-utilization: 0.95
swap-space: 4
# max-seq-len-to-capture: 8192
# num-scheduler-steps: 1
# enable-chunked-prefill: false
# see https://docs.vllm.ai/en/stable/models/engine_args.html for more options.
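
The vllm: mapping above is forwarded to the vLLM argument parser as --key value pairs, while max_probe_steps caps the GPU-memory probing. A standalone sketch of that forwarding, for illustration only (values copied from the template above):

import yaml

raw = """
max_probe_steps: 6
vllm:
  gpu-memory-utilization: 0.95
  swap-space: 4
"""
config = yaml.safe_load(raw)

# Each entry under "vllm" becomes a "--key value" pair appended to the runtime
# args, so it takes precedence over flags given earlier on the command line.
extra_args = []
for key, value in config.get("vllm", {}).items():
    extra_args += [f"--{key}", str(value)]

print(extra_args)                        # ['--gpu-memory-utilization', '0.95', '--swap-space', '4']
print(config.get("max_probe_steps", 6))  # default matches the template: 6
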
187 changes: 126 additions & 61 deletions presets/workspace/inference/vllm/inference_api.py
@@ -3,7 +3,10 @@
import logging
import gc
import os
from typing import Callable, Optional, List
import argparse
from typing import Callable, Optional, List, Any
import yaml
from dataclasses import dataclass

import uvloop
import torch
@@ -21,33 +24,88 @@
format='%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
datefmt='%m-%d %H:%M:%S')

def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
local_rank = int(os.environ.get("LOCAL_RANK",
0)) # Default to 0 if not set
port = 5000 + local_rank # Adjust port based on local rank

server_default_args = {
"disable_frontend_multiprocessing": False,
"port": port,
}
parser.set_defaults(**server_default_args)

# See https://docs.vllm.ai/en/stable/models/engine_args.html for more args
engine_default_args = {
"model": "/workspace/vllm/weights",
"cpu_offload_gb": 0,
"gpu_memory_utilization": 0.95,
"swap_space": 4,
"disable_log_stats": False,
"uvicorn_log_level": "error"
}
parser.set_defaults(**engine_default_args)

# KAITO only args
# They should start with "kaito-" prefix to avoid conflict with vllm args
parser.add_argument("--kaito-adapters-dir", type=str, default="/mnt/adapter", help="Directory where adapters are stored in KAITO preset.")

return parser
class KAITOArgumentParser(argparse.ArgumentParser):
vllm_parser = FlexibleArgumentParser(description="vLLM serving server")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Initialize vllm parser
self.vllm_parser = api_server.make_arg_parser(self.vllm_parser)
self._reset_vllm_defaults()

# KAITO only args
# They should start with "kaito-" prefix to avoid conflict with vllm args
self.add_argument("--kaito-adapters-dir", type=str, default="/mnt/adapter", help="Directory where adapters are stored in KAITO preset.")
self.add_argument("--kaito-config-file", type=str, default="", help="Additional args for KAITO preset.")
self.add_argument("--kaito-max-probe-steps", type=int, default=6, help="Maximum number of steps to find the max available seq len fitting in the GPU memory.")

def _reset_vllm_defaults(self):
local_rank = int(os.environ.get("LOCAL_RANK",
0)) # Default to 0 if not set
port = 5000 + local_rank # Adjust port based on local rank

server_default_args = {
"disable_frontend_multiprocessing": False,
"port": port,
}
self.vllm_parser.set_defaults(**server_default_args)

# See https://docs.vllm.ai/en/stable/models/engine_args.html for more args
engine_default_args = {
"model": "/workspace/vllm/weights",
"cpu_offload_gb": 0,
"gpu_memory_utilization": 0.95,
"swap_space": 4,
"disable_log_stats": False,
"uvicorn_log_level": "error"
}
self.vllm_parser.set_defaults(**engine_default_args)

def parse_args(self, *args, **kwargs):
args = super().parse_known_args(*args, **kwargs)
kaito_args = args[0]
runtime_args = args[1] # Remaining args

# Load KAITO config
if kaito_args.kaito_config_file:
file_config = KaitoConfig.from_yaml(kaito_args.kaito_config_file)
if kaito_args.kaito_max_probe_steps is None:
kaito_args.kaito_max_probe_steps = file_config.max_probe_steps

for key, value in file_config.vllm.items():
runtime_args.append(f"--{key}")
runtime_args.append(str(value))

vllm_args = self.vllm_parser.parse_args(runtime_args, **kwargs)
# Merge KAITO and vLLM args
return argparse.Namespace(**vars(kaito_args), **vars(vllm_args))

def print_help(self, file=None):
super().print_help(file)
        print("\noriginal vLLM server arguments:\n")
self.vllm_parser.print_help(file)

@dataclass
class KaitoConfig:
    # Extra arguments for the vLLM serving server; these are forwarded to the vLLM server.
# This should be in key-value format.
vllm: dict[str, Any]

# Maximum number of steps to find the max available seq len fitting in the GPU memory.
max_probe_steps: int

@staticmethod
def from_yaml(yaml_file: str) -> 'KaitoConfig':
with open(yaml_file, 'r') as file:
config_data = yaml.safe_load(file)
return KaitoConfig(
vllm=config_data.get('vllm', {}),
max_probe_steps=config_data.get('max_probe_steps', 6)
)

def to_yaml(self) -> str:
return yaml.dump(self.__dict__)

def load_lora_adapters(adapters_dir: str) -> Optional[LoRAModulePath]:
lora_list: List[LoRAModulePath] = []
@@ -130,53 +188,60 @@ def is_context_length_safe(executor: ExecutorBase, num_gpu_blocks: int) -> bool:
executor.scheduler_config.max_num_batched_tokens = context_length

try:
logger.info(f"Try to determine available gpu blocks for context length {context_length}")
# see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/engine/llm_engine.py#L477
available_gpu_blocks, _ = executor.determine_num_available_blocks()
except torch.OutOfMemoryError as e:
return False

return available_gpu_blocks >= num_gpu_blocks

def try_set_max_available_seq_len(args: argparse.Namespace):
if args.max_model_len is not None:
logger.info(f"max_model_len is set to {args.max_model_len}, skip probing.")
return

max_probe_steps = 0
if args.kaito_max_probe_steps is not None:
try:
max_probe_steps = int(args.kaito_max_probe_steps)
except ValueError:
raise ValueError("kaito_max_probe_steps must be an integer.")

if max_probe_steps <= 0:
return

engine_args = EngineArgs.from_cli_args(args)
# read the model config from hf weights path.
# vllm will perform different parser for different model architectures
# and read it into a unified EngineConfig.
engine_config = engine_args.create_engine_config()

max_model_len = engine_config.model_config.max_model_len
available_seq_len = max_model_len
logger.info("Try run profiler to find max available seq len")
available_seq_len = find_max_available_seq_len(engine_config, max_probe_steps)
# see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/worker/worker.py#L262
if available_seq_len <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")

if available_seq_len != max_model_len:
logger.info(f"Set max_model_len from {max_model_len} to {available_seq_len}")
args.max_model_len = available_seq_len
else:
logger.info(f"Using model default max_model_len {max_model_len}")

if __name__ == "__main__":
parser = FlexibleArgumentParser(description='vLLM serving server')
parser = api_server.make_arg_parser(parser)
parser = make_arg_parser(parser)
parser = KAITOArgumentParser(description='KAITO wrapper of vLLM serving server')
args = parser.parse_args()

# set LoRA adapters
if args.lora_modules is None:
args.lora_modules = load_lora_adapters(args.kaito_adapters_dir)

if args.max_model_len is None:
max_probe_steps = 6
if os.getenv("MAX_PROBE_STEPS") is not None:
try:
max_probe_steps = int(os.getenv("MAX_PROBE_STEPS"))
except ValueError:
raise ValueError("MAX_PROBE_STEPS must be an integer.")

engine_args = EngineArgs.from_cli_args(args)
# read the model config from hf weights path.
# vllm will perform different parser for different model architectures
# and read it into a unified EngineConfig.
engine_config = engine_args.create_engine_config()

max_model_len = engine_config.model_config.max_model_len
available_seq_len = max_model_len
if max_probe_steps > 0:
logger.info("Try run profiler to find max available seq len")
available_seq_len = find_max_available_seq_len(engine_config, max_probe_steps)
# see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/worker/worker.py#L262
if available_seq_len <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")

if available_seq_len != max_model_len:
logger.info(f"Set max_model_len from {max_model_len} to {available_seq_len}")
args.max_model_len = available_seq_len
else:
logger.info(f"Using model default max_model_len {max_model_len}")
try_set_max_available_seq_len(args)

# Run the serving server
logger.info(f"Starting server on port {args.port}")
30 changes: 24 additions & 6 deletions presets/workspace/inference/vllm/tests/test_vllm_inference_api.py
@@ -13,13 +13,15 @@
# Add the parent directory to sys.path
sys.path.append(parent_dir)

from inference_api import binary_search_with_limited_steps
from inference_api import binary_search_with_limited_steps, KaitoConfig
from huggingface_hub import snapshot_download
import shutil

TEST_MODEL = "facebook/opt-125m"
TEST_ADAPTER_NAME1 = "mylora1"
TEST_ADAPTER_NAME2 = "mylora2"
TEST_MODEL_NAME = "mymodel"
TEST_MODEL_LEN = 1024
CHAT_TEMPLATE = ("{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}"
"{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"
"{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}")
@@ -33,18 +35,33 @@ def setup_server(request, tmp_path_factory, autouse=True):
global TEST_PORT
TEST_PORT = port

# prepare testing adapter images
tmp_file_dir = tmp_path_factory.mktemp("adapter")
print(f"Downloading adapter image to {tmp_file_dir}")
snapshot_download(repo_id="slall/facebook-opt-125M-imdb-lora", local_dir=str(tmp_file_dir / TEST_ADAPTER_NAME1))
snapshot_download(repo_id="slall/facebook-opt-125M-imdb-lora", local_dir=str(tmp_file_dir / TEST_ADAPTER_NAME2))

# prepare testing config file
config_file = tmp_file_dir / "config.yaml"
kaito_config = KaitoConfig(
vllm={
"max-model-len": TEST_MODEL_LEN,
"served-model-name": TEST_MODEL_NAME
},
max_probe_steps=0,
)
with open(config_file, "w") as f:
f.write(kaito_config.to_yaml())

args = [
"python3",
os.path.join(parent_dir, "inference_api.py"),
"--model", TEST_MODEL,
"--chat-template", CHAT_TEMPLATE,
"--max-model-len", "2048", # expected to be overridden by config file
"--port", str(TEST_PORT),
"--kaito-adapters-dir", tmp_file_dir,
"--kaito-config-file", config_file,
]
print(f">>> Starting server on port {TEST_PORT}")
env = os.environ.copy()
@@ -90,7 +107,7 @@ def find_available_port(start_port=5000, end_port=8000):

def test_completions_api(setup_server):
request_data = {
"model": TEST_MODEL,
"model": TEST_MODEL_NAME,
"prompt": "Say this is a test",
"max_tokens": 7,
"temperature": 0.5,
@@ -108,7 +125,7 @@ def test_completions_api(setup_server):

def test_chat_completions_api(setup_server):
request_data = {
"model": TEST_MODEL,
"model": TEST_MODEL_NAME,
"messages": [
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there! How can I help you today?"}
@@ -135,11 +152,12 @@ def test_model_list(setup_server):

assert "data" in data, f"The response should contain a 'data' key, but got {data}"
assert len(data["data"]) == 3, f"The response should contain three models, but got {data['data']}"
assert data["data"][0]["id"] == TEST_MODEL, f"The first model should be the test model, but got {data['data'][0]['id']}"
assert data["data"][0]["id"] == TEST_MODEL_NAME, f"The first model should be the test model, but got {data['data'][0]['id']}"
assert data["data"][0]["max_model_len"] == TEST_MODEL_LEN, f"The first model should have the test model length, but got {data['data'][0]['max_model_len']}"
assert data["data"][1]["id"] == TEST_ADAPTER_NAME2, f"The second model should be the test adapter, but got {data['data'][1]['id']}"
assert data["data"][1]["parent"] == TEST_MODEL, f"The second model should have the test model as parent, but got {data['data'][1]['parent']}"
assert data["data"][1]["parent"] == TEST_MODEL_NAME, f"The second model should have the test model as parent, but got {data['data'][1]['parent']}"
assert data["data"][2]["id"] == TEST_ADAPTER_NAME1, f"The third model should be the test adapter, but got {data['data'][2]['id']}"
assert data["data"][2]["parent"] == TEST_MODEL, f"The third model should have the test model as parent, but got {data['data'][2]['parent']}"
assert data["data"][2]["parent"] == TEST_MODEL_NAME, f"The third model should have the test model as parent, but got {data['data'][2]['parent']}"

def test_binary_search_with_limited_steps():
