Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic adapter support in the grpc server #32

Merged
merged 8 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions proto/generation.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ enum DecodingMethod {

// Request to generate completions for a batch of inputs in a single call.
message BatchedGenerationRequest {
  string model_id = 1;
  // Deprecated in favor of adapter_id
  optional string prefix_id = 2;
  // Id of an adapter in the server's local adapter cache to apply to
  // this request (see the --adapter-cache server argument)
  optional string adapter_id = 4;
  repeated GenerationRequest requests = 3;

  Parameters params = 10;
}

message SingleGenerationRequest {
string model_id = 1;
// Deprecated in favor of adapter_id
optional string prefix_id = 2;
optional string adapter_id = 4;
GenerationRequest request = 3;

Parameters params = 10;
Expand Down
118 changes: 118 additions & 0 deletions vllm/entrypoints/grpc/adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Contains code to map api requests for adapters (e.g. peft prefixes, LoRA)
into valid LLM engine requests"""
import asyncio
import concurrent.futures
import dataclasses
import json
import os
import re
from pathlib import Path
from typing import Dict, Optional, Union

from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
SingleGenerationRequest)
from vllm.entrypoints.grpc.validation import TGISValidationError
from vllm.lora.request import LoRARequest

global_thread_pool = None # used for loading adapter files from disk

VALID_ADAPTER_ID_PATTERN = re.compile("[/\\w\\-]+")


@dataclasses.dataclass
class AdapterMetadata:
    """Cached metadata for a single validated adapter."""
    unique_id: int  # Unique integer for vllm to identify the adapter
    adapter_type: str  # The string name of the peft adapter type, e.g. LORA
    full_path: str  # Local filesystem path to the adapter's directory


@dataclasses.dataclass
class AdapterStore:
    """Registry of adapters available to the server."""
    cache_path: str  # Path to local store of adapters to load from
    adapters: Dict[str, AdapterMetadata]  # Validated adapters, keyed by adapter_id
    next_unique_id: int = 1  # Source of integer ids for newly cached adapters


async def validate_adapters(
        request: Union[SingleGenerationRequest, BatchedGenerationRequest],
        adapter_store: Optional[AdapterStore]) -> Dict[str, LoRARequest]:
    """Takes the adapter name from the request and constructs a valid
    engine request if one is set.

    Args:
        request: incoming generation request; its ``adapter_id`` field,
            if non-empty, names an adapter under the store's cache path.
        adapter_store: the configured adapter cache, or None when adapter
            support is disabled.

    Returns:
        The kwarg dictionary to add to an engine.generate() call
        (empty when no adapter was requested).

    Raises:
        ValueError: if an adapter was requested but the store is not
            configured, the id is malformed, the adapter files do not
            exist, or the adapter type is unsupported.
    """
    global global_thread_pool
    adapter_id = request.adapter_id

    if adapter_id and not adapter_store:
        TGISValidationError.AdaptersDisabled.error()

    if not adapter_id or not adapter_store:
        return {}

    # If not already cached, we need to validate that files exist and
    # grab the type out of the adapter_config.json file
    if (adapter_metadata := adapter_store.adapters.get(adapter_id)) is None:
        # Reject path-traversal attempts before touching the filesystem
        _reject_bad_adapter_id(adapter_id)
        local_adapter_path = os.path.join(adapter_store.cache_path, adapter_id)

        loop = asyncio.get_running_loop()
        if global_thread_pool is None:
            global_thread_pool = concurrent.futures.ThreadPoolExecutor(
                max_workers=2)

        # Run the blocking file access off the event loop
        adapter_type = await loop.run_in_executor(global_thread_pool,
                                                  _get_adapter_type_from_file,
                                                  adapter_id,
                                                  local_adapter_path)

        # Add to cache. Fix: advance next_unique_id so each adapter gets
        # a distinct integer id -- vllm uses this id (lora_int_id) to
        # distinguish loaded adapters, so ids must never collide.
        adapter_metadata = AdapterMetadata(
            unique_id=adapter_store.next_unique_id,
            adapter_type=adapter_type,
            full_path=local_adapter_path)
        adapter_store.next_unique_id += 1
        adapter_store.adapters[adapter_id] = adapter_metadata

    # Build the proper vllm request object
    if adapter_metadata.adapter_type == "LORA":
        lora_request = LoRARequest(lora_name=adapter_id,
                                   lora_int_id=adapter_metadata.unique_id,
                                   lora_local_path=adapter_metadata.full_path)
        return {"lora_request": lora_request}

    # All other types unsupported
    TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type)


def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
"""This function does all the filesystem access required to deduce the type
of the adapter. It's run in a separate thread pool executor so that file
access does not block the main event loop."""
if not os.path.exists(adapter_path):
TGISValidationError.AdapterNotFound.error(adapter_id,
"directory does not exist")

adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
if not os.path.exists(adapter_config_path):
TGISValidationError.AdapterNotFound.error(
adapter_id, "invalid adapter: no adapter_config.json found")

# NB: blocks event loop
with open(adapter_config_path) as adapter_config_file:
adapter_config = json.load(adapter_config_file)

return adapter_config.get("peft_type", None)


def _reject_bad_adapter_id(adapter_id: str) -> None:
    """Raise if the adapter id attempts path traversal or has invalid file path
    characters"""
    if VALID_ADAPTER_ID_PATTERN.fullmatch(adapter_id) is None:
        TGISValidationError.InvalidAdapterID.error(adapter_id)

    # Join onto a dummy root and normalize: if the result escapes that
    # root, the id was attempting path traversal
    anchor = Path("/some/file/root")
    normalized = os.path.normpath(anchor / adapter_id)
    if not normalized.startswith(f"{anchor}/"):
        TGISValidationError.InvalidAdapterID.error(adapter_id)
32 changes: 30 additions & 2 deletions vllm/entrypoints/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from vllm import (AsyncLLMEngine, CompletionOutput, RequestOutput,
SamplingParams)
from vllm.config import ModelConfig
from vllm.entrypoints.grpc.adapters import AdapterStore, validate_adapters
from vllm.entrypoints.grpc.pb import generation_pb2_grpc # type: ignore
# yapf: disable
from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
Expand All @@ -33,6 +34,7 @@
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.inputs import TextTokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
from vllm.tgis_utils import logs
from vllm.tgis_utils.guided_decoding import (
Expand Down Expand Up @@ -119,6 +121,13 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
self.skip_special_tokens = not args.output_special_tokens
self.default_include_stop_seqs = args.default_include_stop_seqs

self.adapter_store: Optional[AdapterStore] = None
if args.adapter_cache:
self.adapter_store = AdapterStore(
cache_path=args.adapter_cache,
adapters={}
)

async def _post_init(self):
self.config = await self.engine.get_model_config()
# self.tokenizer_group = await self.engine.get_tokenizer_group()
Expand Down Expand Up @@ -148,6 +157,9 @@ async def Generate(self, request: BatchedGenerationRequest,

generators = []
max_is_token_limit = [False] * request_count

adapter_kwargs = await self._validate_adapters(request, context)

for i, req in enumerate(request.requests):
input_ids, max_is_token_limit[i]\
= await self._validate_prompt_and_tokenize(
Expand All @@ -161,7 +173,8 @@ async def Generate(self, request: BatchedGenerationRequest,
# re-tokenized when `prompt_token_ids` is supplied
self.engine.generate(inputs=inputs,
sampling_params=sampling_params,
request_id=f"{request_id}-{i}"),
request_id=f"{request_id}-{i}",
**adapter_kwargs),
)

# TODO handle cancellation
Expand Down Expand Up @@ -218,6 +231,7 @@ async def GenerateStream(
sampling_params, truncate_input_tokens, request.request.text,
context)

adapter_kwargs = await self._validate_adapters(request, context)
inputs = TextTokensPrompt(
prompt=request.request.text,
prompt_token_ids=input_ids
Expand All @@ -228,7 +242,8 @@ async def GenerateStream(
# re-tokenized when `prompt_token_ids` is supplied
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id
request_id=request_id,
**adapter_kwargs
)

resp_options = request.params.response
Expand Down Expand Up @@ -442,6 +457,19 @@ async def _validate_and_convert_params(

return sampling_params, deadline

async def _validate_adapters(self,
                             request: Union[SingleGenerationRequest,
                                            BatchedGenerationRequest],
                             context: ServicerContext) \
        -> Dict[str, LoRARequest]:
    """Resolve any adapter referenced by the request into engine kwargs,
    aborting the RPC with INVALID_ARGUMENT if validation fails."""
    try:
        adapter_kwargs = await validate_adapters(
            request=request, adapter_store=self.adapter_store)
    except ValueError as err:
        service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
        await context.abort(StatusCode.INVALID_ARGUMENT, str(err))
    return adapter_kwargs

@staticmethod
def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
time_limit_reached: bool
Expand Down
8 changes: 7 additions & 1 deletion vllm/entrypoints/grpc/validation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from enum import Enum

from vllm import SamplingParams
Expand Down Expand Up @@ -39,8 +40,13 @@ class TGISValidationError(str, Enum):

# Additions that are _not_ in TGIS
TopN = "top_n_tokens ({0}) must be <= {1}"
AdapterNotFound = "can't retrieve adapter with id '{0}': {1}"
AdaptersDisabled = "adapter_id supplied but no adapter store was configured"
AdapterUnsupported = "adapter type {0} is not currently supported"
InvalidAdapterID = ("Invalid adapter id '{0}', must contain only "
"alphanumeric, _ and - and /")

def error(self, *args, **kwargs):
def error(self, *args, **kwargs) -> typing.NoReturn:
        """Raise a ValueError whose message is this member's template
        formatted with the given positional/keyword arguments."""
        raise ValueError(self.value.format(*args, **kwargs))

Expand Down
2 changes: 2 additions & 0 deletions vllm/tgis_utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--tls-key-path', type=str)
# map to ssl_ca_certs
parser.add_argument('--tls-client-ca-cert-path', type=str)
# add a path when peft adapters will be loaded from
parser.add_argument('--adapter-cache', type=str)

# TODO check/add other args here

Expand Down
8 changes: 6 additions & 2 deletions vllm/tgis_utils/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def log_response(
response=response,
params=request.params,
prefix_id=request.prefix_id,
adapter_id=request.adapter_id,
engine_metrics=engine_metrics,
start_time=start_time,
kind_log=kind_log,
Expand All @@ -57,6 +58,7 @@ def log_error(request: Union[BatchedGenerationRequest,
# of just logging the simple string representation of the error
param_str = text_format.MessageToString(request.params, as_one_line=True)
prefix_id = request.prefix_id
adapter_id = request.adapter_id

if isinstance(request, BatchedGenerationRequest):
method_str = "generate"
Expand All @@ -69,13 +71,14 @@ def log_error(request: Union[BatchedGenerationRequest,
input_chars = sum(len(input_) for input_ in inputs)

span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
f"input_chars=[{input_chars}] params={param_str}")
f"adapter_id={adapter_id} input_chars=[{input_chars}] "
f"params={param_str}")

logger.error("%s: %s", span_str, exception_str)


def _log_response(inputs: List[str], params: Parameters, prefix_id: str,
response: GenerationResponse,
adapter_id: str, response: GenerationResponse,
engine_metrics: Optional[RequestMetrics], start_time: float,
kind_log: str, method_str: str, logger: logging.Logger):
"""Logs responses similar to how the TGIS server does"""
Expand All @@ -99,6 +102,7 @@ def _log_response(inputs: List[str], params: Parameters, prefix_id: str,

paramstr = text_format.MessageToString(params, as_one_line=True)
span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
f"adapter_id={adapter_id} "
f"input_chars=[{input_chars}] params={paramstr} "
f"tokenization_time={tokenization_time * 1e3:.2f}ms "
f"queue_time={queue_time * 1e3:.2f}ms "
Expand Down
Loading