Generic adapter support in the grpc server (#32)
Adds support for multi-LoRA adapters.

Passing tests were added in this PR:
https://github.ibm.com/ai-foundation/tgis-deploy-tests/pull/25/files

---------

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
joerunde authored Jun 11, 2024
1 parent 0fe7794 commit 79b7364
Showing 6 changed files with 167 additions and 5 deletions.
4 changes: 4 additions & 0 deletions proto/generation.proto
@@ -27,15 +27,19 @@ enum DecodingMethod {

message BatchedGenerationRequest {
    string model_id = 1;
    // Deprecated in favor of adapter_id
    optional string prefix_id = 2;
    optional string adapter_id = 4;
    repeated GenerationRequest requests = 3;

    Parameters params = 10;
}

message SingleGenerationRequest {
    string model_id = 1;
    // Deprecated in favor of adapter_id
    optional string prefix_id = 2;
    optional string adapter_id = 4;
    GenerationRequest request = 3;

    Parameters params = 10;
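Both request messages now carry an optional adapter_id alongside the deprecated prefix_id. As a quick illustration of the client side, the sketch below builds a batched request against the generated stubs; the stub class name and the server address are assumptions not confirmed by this diff, while the text field on GenerationRequest is the one the server reads.

# Hypothetical client sketch: select a LoRA adapter per request via adapter_id.
import grpc

from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
                                                     GenerationRequest)
from vllm.entrypoints.grpc.pb.generation_pb2_grpc import GenerationServiceStub  # stub name assumed

with grpc.insecure_channel("localhost:8033") as channel:  # address assumed
    stub = GenerationServiceStub(channel)
    response = stub.Generate(
        BatchedGenerationRequest(
            model_id="base-model",
            adapter_id="my-lora",  # must exist under the server's --adapter-cache
            requests=[GenerationRequest(text="Hello, adapters!")],
        ))
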
118 changes: 118 additions & 0 deletions vllm/entrypoints/grpc/adapters.py
@@ -0,0 +1,118 @@
"""Contains code to map api requests for adapters (e.g. peft prefixes, LoRA)
into valid LLM engine requests"""
import asyncio
import concurrent.futures
import dataclasses
import json
import os
import re
from pathlib import Path
from typing import Dict, Optional, Union

from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
SingleGenerationRequest)
from vllm.entrypoints.grpc.validation import TGISValidationError
from vllm.lora.request import LoRARequest

global_thread_pool = None # used for loading adapter files from disk

VALID_ADAPTER_ID_PATTERN = re.compile("[/\\w\\-]+")


@dataclasses.dataclass
class AdapterMetadata:
unique_id: int # Unique integer for vllm to identify the adapter
adapter_type: str # The string name of the peft adapter type, e.g. LORA
full_path: str


@dataclasses.dataclass
class AdapterStore:
cache_path: str # Path to local store of adapters to load from
adapters: Dict[str, AdapterMetadata]
next_unique_id: int = 1


async def validate_adapters(
request: Union[SingleGenerationRequest, BatchedGenerationRequest],
adapter_store: Optional[AdapterStore]) -> Dict[str, LoRARequest]:
"""Takes the adapter name from the request and constructs a valid
engine request if one is set. Raises if the requested adapter
does not exist or adapter type is unsupported
Returns the kwarg dictionary to add to an engine.generate() call.
"""
global global_thread_pool
adapter_id = request.adapter_id

if adapter_id and not adapter_store:
TGISValidationError.AdaptersDisabled.error()

if not adapter_id or not adapter_store:
return {}

# If not already cached, we need to validate that files exist and
# grab the type out of the adapter_config.json file
if (adapter_metadata := adapter_store.adapters.get(adapter_id)) is None:
_reject_bad_adapter_id(adapter_id)
local_adapter_path = os.path.join(adapter_store.cache_path, adapter_id)

loop = asyncio.get_running_loop()
if global_thread_pool is None:
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=2)

adapter_type = await loop.run_in_executor(global_thread_pool,
_get_adapter_type_from_file,
adapter_id,
local_adapter_path)

# Add to cache
adapter_metadata = AdapterMetadata(
unique_id=adapter_store.next_unique_id,
adapter_type=adapter_type,
full_path=local_adapter_path)
adapter_store.adapters[adapter_id] = adapter_metadata

# Build the proper vllm request object
if adapter_metadata.adapter_type == "LORA":
lora_request = LoRARequest(lora_name=adapter_id,
lora_int_id=adapter_metadata.unique_id,
lora_local_path=adapter_metadata.full_path)
return {"lora_request": lora_request}

# All other types unsupported
TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type)


def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
"""This function does all the filesystem access required to deduce the type
of the adapter. It's run in a separate thread pool executor so that file
access does not block the main event loop."""
if not os.path.exists(adapter_path):
TGISValidationError.AdapterNotFound.error(adapter_id,
"directory does not exist")

adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
if not os.path.exists(adapter_config_path):
TGISValidationError.AdapterNotFound.error(
adapter_id, "invalid adapter: no adapter_config.json found")

# NB: blocks event loop
with open(adapter_config_path) as adapter_config_file:
adapter_config = json.load(adapter_config_file)

return adapter_config.get("peft_type", None)


def _reject_bad_adapter_id(adapter_id: str) -> None:
"""Raise if the adapter id attempts path traversal or has invalid file path
characters"""
if not VALID_ADAPTER_ID_PATTERN.fullmatch(adapter_id):
TGISValidationError.InvalidAdapterID.error(adapter_id)

# Check for path traversal
root_path = Path("/some/file/root")
derived_path = root_path / adapter_id
if not os.path.normpath(derived_path).startswith(str(root_path) + "/"):
TGISValidationError.InvalidAdapterID.error(adapter_id)
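
To make the lookup flow above concrete, the sketch below stages a throwaway adapter cache on disk and calls validate_adapters against it; the adapter name, cache location, and config contents are illustrative, and peft_type is the only key the loader actually reads.

# Illustrative usage of validate_adapters with a temporary adapter cache.
import asyncio
import json
import tempfile
from pathlib import Path

from vllm.entrypoints.grpc.adapters import AdapterStore, validate_adapters
from vllm.entrypoints.grpc.pb.generation_pb2 import BatchedGenerationRequest


async def demo() -> None:
    cache = Path(tempfile.mkdtemp())
    # Expected layout: <adapter-cache>/<adapter_id>/adapter_config.json
    adapter_dir = cache / "my-lora"
    adapter_dir.mkdir()
    (adapter_dir / "adapter_config.json").write_text(
        json.dumps({"peft_type": "LORA"}))  # minimal config for the sketch

    store = AdapterStore(cache_path=str(cache), adapters={})
    request = BatchedGenerationRequest(adapter_id="my-lora")

    kwargs = await validate_adapters(request, store)
    # kwargs is {"lora_request": LoRARequest(...)} and is splatted into
    # engine.generate(**kwargs) by the gRPC server.
    print(kwargs["lora_request"].lora_name, kwargs["lora_request"].lora_int_id)


asyncio.run(demo())
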
32 changes: 30 additions & 2 deletions vllm/entrypoints/grpc/grpc_server.py
@@ -14,6 +14,7 @@
from vllm import (AsyncLLMEngine, CompletionOutput, RequestOutput,
                  SamplingParams)
from vllm.config import ModelConfig
from vllm.entrypoints.grpc.adapters import AdapterStore, validate_adapters
from vllm.entrypoints.grpc.pb import generation_pb2_grpc  # type: ignore
# yapf: disable
from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
@@ -33,6 +34,7 @@
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.inputs import TextTokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
from vllm.tgis_utils import logs
from vllm.tgis_utils.guided_decoding import (
@@ -119,6 +121,13 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
        self.skip_special_tokens = not args.output_special_tokens
        self.default_include_stop_seqs = args.default_include_stop_seqs

        self.adapter_store: Optional[AdapterStore] = None
        if args.adapter_cache:
            self.adapter_store = AdapterStore(
                cache_path=args.adapter_cache,
                adapters={}
            )

    async def _post_init(self):
        self.config = await self.engine.get_model_config()
        # self.tokenizer_group = await self.engine.get_tokenizer_group()
@@ -148,6 +157,9 @@ async def Generate(self, request: BatchedGenerationRequest,

        generators = []
        max_is_token_limit = [False] * request_count

        adapter_kwargs = await self._validate_adapters(request, context)

        for i, req in enumerate(request.requests):
            input_ids, max_is_token_limit[i]\
                = await self._validate_prompt_and_tokenize(
@@ -161,7 +173,8 @@
                # re-tokenized when `prompt_token_ids` is supplied
                self.engine.generate(inputs=inputs,
                                     sampling_params=sampling_params,
                                     request_id=f"{request_id}-{i}"),
                                     request_id=f"{request_id}-{i}",
                                     **adapter_kwargs),
            )

        # TODO handle cancellation
@@ -218,6 +231,7 @@ async def GenerateStream(
            sampling_params, truncate_input_tokens, request.request.text,
            context)

        adapter_kwargs = await self._validate_adapters(request, context)
        inputs = TextTokensPrompt(
            prompt=request.request.text,
            prompt_token_ids=input_ids
@@ -228,7 +242,8 @@
            # re-tokenized when `prompt_token_ids` is supplied
            inputs=inputs,
            sampling_params=sampling_params,
            request_id=request_id
            request_id=request_id,
            **adapter_kwargs
        )

        resp_options = request.params.response
@@ -442,6 +457,19 @@ async def _validate_and_convert_params(

        return sampling_params, deadline

    async def _validate_adapters(self,
                                 request: Union[SingleGenerationRequest,
                                                BatchedGenerationRequest],
                                 context: ServicerContext) \
            -> Dict[str, LoRARequest]:
        try:
            adapters = await validate_adapters(
                request=request, adapter_store=self.adapter_store)
        except ValueError as e:
            service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
            await context.abort(StatusCode.INVALID_ARGUMENT, str(e))
        return adapters

    @staticmethod
    def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
                        time_limit_reached: bool
8 changes: 7 additions & 1 deletion vllm/entrypoints/grpc/validation.py
@@ -1,3 +1,4 @@
import typing
from enum import Enum

from vllm import SamplingParams
@@ -39,8 +40,13 @@ class TGISValidationError(str, Enum):

    # Additions that are _not_ in TGIS
    TopN = "top_n_tokens ({0}) must be <= {1}"
    AdapterNotFound = "can't retrieve adapter with id '{0}': {1}"
    AdaptersDisabled = "adapter_id supplied but no adapter store was configured"
    AdapterUnsupported = "adapter type {0} is not currently supported"
    InvalidAdapterID = ("Invalid adapter id '{0}', must contain only "
                        "alphanumeric, _ and - and /")

    def error(self, *args, **kwargs):
    def error(self, *args, **kwargs) -> typing.NoReturn:
        """Raises a ValueError with a nicely formatted string"""
        raise ValueError(self.value.format(*args, **kwargs))

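Each of the new enum members is a format string, and error() now advertises that it never returns; a small illustrative check (the rejected adapter type is hypothetical):

# Illustrative: enum members render into ValueError messages via .error().
from vllm.entrypoints.grpc.validation import TGISValidationError

try:
    TGISValidationError.AdapterUnsupported.error("PROMPT_TUNING")
except ValueError as e:
    assert str(e) == "adapter type PROMPT_TUNING is not currently supported"
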
2 changes: 2 additions & 0 deletions vllm/tgis_utils/args.py
@@ -82,6 +82,8 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument('--tls-key-path', type=str)
    # map to ssl_ca_certs
    parser.add_argument('--tls-client-ca-cert-path', type=str)
    # path where peft adapters will be loaded from
    parser.add_argument('--adapter-cache', type=str)

    # TODO check/add other args here

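As a quick sanity check of the new flag, the TGIS argument parser can be exercised on its own; the cache path is an arbitrary example, and this assumes none of the other TGIS flags are required.

# Illustrative: --adapter-cache lands on args.adapter_cache, which grpc_server.py
# uses to construct its AdapterStore.
import argparse

from vllm.tgis_utils.args import add_tgis_args

parser = add_tgis_args(argparse.ArgumentParser())
args = parser.parse_args(["--adapter-cache", "/mnt/adapter-cache"])
assert args.adapter_cache == "/mnt/adapter-cache"
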
8 changes: 6 additions & 2 deletions vllm/tgis_utils/logs.py
@@ -41,6 +41,7 @@ def log_response(
        response=response,
        params=request.params,
        prefix_id=request.prefix_id,
        adapter_id=request.adapter_id,
        engine_metrics=engine_metrics,
        start_time=start_time,
        kind_log=kind_log,
@@ -57,6 +58,7 @@ def log_error(request: Union[BatchedGenerationRequest,
    # of just logging the simple string representation of the error
    param_str = text_format.MessageToString(request.params, as_one_line=True)
    prefix_id = request.prefix_id
    adapter_id = request.adapter_id

    if isinstance(request, BatchedGenerationRequest):
        method_str = "generate"
@@ -69,13 +71,14 @@
    input_chars = sum(len(input_) for input_ in inputs)

    span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
                f"input_chars=[{input_chars}] params={param_str}")
                f"adapter_id={adapter_id} input_chars=[{input_chars}] "
                f"params={param_str}")

    logger.error("%s: %s", span_str, exception_str)


def _log_response(inputs: List[str], params: Parameters, prefix_id: str,
                  response: GenerationResponse,
                  adapter_id: str, response: GenerationResponse,
                  engine_metrics: Optional[RequestMetrics], start_time: float,
                  kind_log: str, method_str: str, logger: logging.Logger):
    """Logs responses similar to how the TGIS server does"""
@@ -99,6 +102,7 @@

    paramstr = text_format.MessageToString(params, as_one_line=True)
    span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
                f"adapter_id={adapter_id} "
                f"input_chars=[{input_chars}] params={paramstr} "
                f"tokenization_time={tokenization_time * 1e3:.2f}ms "
                f"queue_time={queue_time * 1e3:.2f}ms "
