Generic adapter support in the grpc server #32

Merged · 8 commits · Jun 11, 2024
Changes from 4 commits
4 changes: 4 additions & 0 deletions proto/generation.proto
@@ -27,16 +27,20 @@ enum DecodingMethod {

message BatchedGenerationRequest {
string model_id = 1;
// Deprecated in favor of adapter_id
optional string prefix_id = 2;
repeated GenerationRequest requests = 3;
optional string adapter_id = 4;

Parameters params = 10;
}

message SingleGenerationRequest {
string model_id = 1;
// Deprecated in favor of adapter_id
optional string prefix_id = 2;
GenerationRequest request = 3;
optional string adapter_id = 4;

Parameters params = 10;
}
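To illustrate the request shape after this change, here is a minimal sketch using plain dataclasses as a stand-in for the generated protobuf messages (the real code uses the `generation_pb2` stubs; the field layout below mirrors the proto above, everything else is illustrative):

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class GenerationRequest:
    text: str

@dataclass
class BatchedGenerationRequest:
    model_id: str
    requests: List[GenerationRequest] = field(default_factory=list)
    prefix_id: Optional[str] = None   # deprecated in favor of adapter_id
    adapter_id: Optional[str] = None  # names a peft adapter in the cache

# A caller targeting a LoRA adapter sets adapter_id instead of prefix_id.
req = BatchedGenerationRequest(
    model_id="my-model",
    requests=[GenerationRequest(text="Hello")],
    adapter_id="my-lora-adapter",
)
```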
81 changes: 81 additions & 0 deletions vllm/entrypoints/grpc/adapters.py
@@ -0,0 +1,81 @@
"""Contains code to map api requests for adapters (e.g. peft prefixes, LoRA)
into valid LLM engine requests"""
import dataclasses
import json
import os
from typing import Dict, Optional, Union

from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
SingleGenerationRequest)
from vllm.entrypoints.grpc.validation import TGISValidationError
from vllm.lora.request import LoRARequest


@dataclasses.dataclass
class AdapterMetadata:
unique_id: int # Unique integer for vllm to identify the adapter
adapter_type: str # The string name of the peft adapter type, e.g. LORA
full_path: str


@dataclasses.dataclass
class AdapterStore:
cache_path: str # Path to local store of adapters to load from
adapters: Dict[str, AdapterMetadata]
next_unique_id: int = 1


def validate_adapters(
request: Union[SingleGenerationRequest, BatchedGenerationRequest],
adapter_store: Optional[AdapterStore]) -> Dict[str, LoRARequest]:
"""Takes the adapter name from the request and constructs a valid
engine request if one is set. Raises if the requested adapter
does not exist or adapter type is unsupported

Returns the kwarg dictionary to add to an engine.generate() call.
"""
adapter_id = request.adapter_id

if adapter_id and not adapter_store:
TGISValidationError.AdaptersDisabled.error()

if not adapter_id or not adapter_store:
return {}

# If not already cached, we need to validate that files exist and
# grab the type out of the adapter_config.json file
if adapter_id not in adapter_store.adapters:
local_adapter_path = os.path.join(adapter_store.cache_path, adapter_id)
Contributor: I think we should sanitize the adapter_id here to make sure that the user can't send funny things like ../../../etc/passwd.

Collaborator (author): done!

if not os.path.exists(local_adapter_path):
TGISValidationError.AdapterNotFound.error(
adapter_id, "directory does not exist")

adapter_config_path = os.path.join(local_adapter_path,
"adapter_config.json")
if not os.path.exists(adapter_config_path):
TGISValidationError.AdapterNotFound.error(
adapter_id, "invalid adapter: no adapter_config.json found")

# NB: blocks event loop
Contributor: I think this will be important to address - to remove all the file access from the event loop.

Collaborator (author): Yeah, I looked into this a bit and it sounds like the asyncio file access in third-party libs is... not very good.

I'm not 100% up to speed on event loops, would we want to make a new executor for this, sorta like

    file_load_executor = ThreadPoolExecutor(max_workers=n)
    await loop.run_in_executor(file_load_executor,
                               _load_the_config_json_file, ...)

or would that just also block the loop?

Contributor: Yeah, exactly... probably should just make that function contain all the code that runs if we don't find the adapter in the dict (i.e. checking on disk, loading it, etc.).

There's a default asyncio executor that can be used for this kind of thing, or we may want a static one rather than creating one on the fly (not that you were necessarily suggesting that).

Collaborator (author): Cool, I'll see if I can get that working quickly.

Collaborator (author): @njhill can I get a run from your static analysis on this change?

with open(adapter_config_path) as adapter_config_file:
adapter_config = json.load(adapter_config_file)

adapter_type = adapter_config.get("peft_type", None)

# Add to cache
adapter_store.adapters[adapter_id] = AdapterMetadata(
unique_id=adapter_store.next_unique_id,
adapter_type=adapter_type,
full_path=local_adapter_path)

# Build the proper vllm request object
adapter_metadata = adapter_store.adapters[adapter_id]
if adapter_metadata.adapter_type == "LORA":
lora_request = LoRARequest(lora_name=adapter_id,
lora_int_id=adapter_metadata.unique_id,
lora_local_path=adapter_metadata.full_path)
return {"lora_request": lora_request}

# All other types unsupported
TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type)
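The event-loop fix discussed in the review thread above amounts to pushing the blocking disk I/O onto a thread-pool executor. A minimal sketch (function names are illustrative, not the merged implementation; passing `None` selects the loop's default executor):

```python
import asyncio
import json
import os

def _load_adapter_config_sync(adapter_path: str) -> dict:
    # Blocking disk I/O: fine in a worker thread, not on the event loop.
    config_path = os.path.join(adapter_path, "adapter_config.json")
    with open(config_path) as f:
        return json.load(f)

async def load_adapter_config(adapter_path: str) -> dict:
    loop = asyncio.get_running_loop()
    # None -> use the loop's default ThreadPoolExecutor; a static, shared
    # executor could be substituted here as the reviewer suggests.
    return await loop.run_in_executor(None, _load_adapter_config_sync,
                                      adapter_path)
```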
34 changes: 32 additions & 2 deletions vllm/entrypoints/grpc/grpc_server.py
@@ -14,6 +14,7 @@
from vllm import (AsyncLLMEngine, CompletionOutput, RequestOutput,
SamplingParams)
from vllm.config import ModelConfig
from vllm.entrypoints.grpc.adapters import AdapterStore, validate_adapters
from vllm.entrypoints.grpc.pb import generation_pb2_grpc # type: ignore
# yapf: disable
from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
@@ -32,6 +33,7 @@
from vllm.entrypoints.grpc.validation import validate_input, validate_params
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
from vllm.tgis_utils import logs
from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
@@ -116,9 +118,17 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
self.skip_special_tokens = not args.output_special_tokens
self.default_include_stop_seqs = args.default_include_stop_seqs

self.adapter_store: Optional[AdapterStore] = None
if args.adapter_cache:
self.adapter_store = AdapterStore(
cache_path=args.adapter_cache,
adapters={}
)

async def _post_init(self):
self.config = await self.engine.get_model_config()
self.tokenizer_group = await self.engine.get_tokenizer_group()
# self.tokenizer_group = await self.engine.get_tokenizer_group()
self.tokenizer_group = self.engine.engine.tokenizer
self.tokenizer = await self.engine.get_tokenizer()

# Swap in the special TGIS stats logger
@@ -144,6 +154,9 @@ async def Generate(self, request: BatchedGenerationRequest,

generators = []
max_is_token_limit = [False] * request_count

adapter_kwargs = await self._validate_adapters(request, context)

for i, req in enumerate(request.requests):
input_ids, max_is_token_limit[i]\
= await self._validate_prompt_and_tokenize(
@@ -154,7 +167,8 @@
self.engine.generate(prompt=req.text,
sampling_params=sampling_params,
request_id=f"{request_id}-{i}",
prompt_token_ids=input_ids),
prompt_token_ids=input_ids,
**adapter_kwargs),
)

# TODO handle cancellation
@@ -210,13 +224,16 @@ async def GenerateStream(
sampling_params, truncate_input_tokens, request.request.text,
context)

adapter_kwargs, _ = await self._validate_adapters(request, context)
Contributor: Not a tuple now, right?

Collaborator (author): oh yeah, totally not. Interestingly python seems totally fine with the unpacking mismatch if you leave an underscore, TIL

result_generator = self.engine.generate(
# prompt is supplied for observability, the text is not
# re-tokenized when `prompt_token_ids` is supplied
prompt=request.request.text,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=input_ids,
**adapter_kwargs
)

resp_options = request.params.response
@@ -423,6 +440,19 @@ async def _validate_and_convert_params(

return sampling_params, deadline

async def _validate_adapters(self,
request: Union[SingleGenerationRequest,
BatchedGenerationRequest],
context: ServicerContext) \
-> Dict[str, LoRARequest]:
try:
adapters = validate_adapters(
request=request, adapter_store=self.adapter_store)
except ValueError as e:
service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
await context.abort(StatusCode.INVALID_ARGUMENT, str(e))
return adapters

@staticmethod
def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
time_limit_reached: bool
6 changes: 5 additions & 1 deletion vllm/entrypoints/grpc/validation.py
@@ -1,3 +1,4 @@
import typing
from enum import Enum

from vllm import SamplingParams
@@ -39,8 +40,11 @@ class TGISValidationError(str, Enum):

# Additions that are _not_ in TGIS
TopN = "top_n_tokens ({0}) must be <= {1}"
AdapterNotFound = "can't retrieve adapter with id '{0}': {1}"
AdaptersDisabled = "adapter_id supplied but no adapter store was configured"
AdapterUnsupported = "adapter type {0} is not currently supported"

def error(self, *args, **kwargs):
def error(self, *args, **kwargs) -> typing.NoReturn:
"""Raises a ValueError with a nicely formatted string"""
raise ValueError(self.value.format(*args, **kwargs))
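The `NoReturn` annotation added here lets type checkers know that code after a `.error()` call is unreachable. A minimal stand-alone reproduction of the pattern (illustrative class, not the real `TGISValidationError`):

```python
import typing
from enum import Enum

class ValidationError(str, Enum):
    # Each member's value is a format string for the error message.
    AdapterNotFound = "can't retrieve adapter with id '{0}': {1}"

    def error(self, *args, **kwargs) -> typing.NoReturn:
        """Raises a ValueError with a nicely formatted string."""
        raise ValueError(self.value.format(*args, **kwargs))
```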

2 changes: 2 additions & 0 deletions vllm/tgis_utils/args.py
@@ -82,6 +82,8 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--tls-key-path', type=str)
# map to ssl_ca_certs
parser.add_argument('--tls-client-ca-cert-path', type=str)
# path where peft adapters will be loaded from
parser.add_argument('--adapter-cache', type=str)

# TODO check/add other args here

8 changes: 6 additions & 2 deletions vllm/tgis_utils/logs.py
@@ -41,6 +41,7 @@ def log_response(
response=response,
params=request.params,
prefix_id=request.prefix_id,
adapter_id=request.adapter_id,
engine_metrics=engine_metrics,
start_time=start_time,
kind_log=kind_log,
@@ -57,6 +58,7 @@ def log_error(request: Union[BatchedGenerationRequest,
# of just logging the simple string representation of the error
param_str = text_format.MessageToString(request.params, as_one_line=True)
prefix_id = request.prefix_id
adapter_id = request.adapter_id

if isinstance(request, BatchedGenerationRequest):
method_str = "generate"
@@ -69,13 +71,14 @@
input_chars = sum(len(input_) for input_ in inputs)

span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
f"input_chars=[{input_chars}] params={param_str}")
f"adapter_id={adapter_id} input_chars=[{input_chars}] "
f"params={param_str}")

logger.error("%s: %s", span_str, exception_str)


def _log_response(inputs: List[str], params: Parameters, prefix_id: str,
response: GenerationResponse,
adapter_id: str, response: GenerationResponse,
engine_metrics: Optional[RequestMetrics], start_time: float,
kind_log: str, method_str: str, logger: logging.Logger):
"""Logs responses similar to how the TGIS server does"""
@@ -99,6 +102,7 @@ def _log_response(inputs: List[str], params: Parameters, prefix_id: str,

paramstr = text_format.MessageToString(params, as_one_line=True)
span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
f"adapter_id={adapter_id} "
f"input_chars=[{input_chars}] params={paramstr} "
f"tokenization_time={tokenization_time * 1e3:.2f}ms "
f"queue_time={queue_time * 1e3:.2f}ms "
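A quick sketch of what the span string built here looks like once `adapter_id` is included; all values below are hypothetical placeholders, and the timing fields that follow in the real code are omitted:

```python
method_str = "generate"
short_input = "'Hello wo...'"
prefix_id = ""          # empty when no prefix is used
adapter_id = "my-lora"  # the newly logged field
input_chars = 11
paramstr = "method: GREEDY"

# Mirrors the f-string layout in _log_response above ({{ escapes a brace).
span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
            f"adapter_id={adapter_id} "
            f"input_chars=[{input_chars}] params={paramstr}")
print(span_str)
```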