Generic adapter support in the grpc server #32
Changes from 4 commits
New file - the adapter-mapping module (imported by the server below as vllm.entrypoints.grpc.adapters):

@@ -0,0 +1,81 @@
"""Contains code to map api requests for adapters (e.g. peft prefixes, LoRA) | ||
into valid LLM engine requests""" | ||
import dataclasses | ||
import json | ||
import os | ||
from typing import Dict, Optional, Union | ||
|
||
from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest, | ||
SingleGenerationRequest) | ||
from vllm.entrypoints.grpc.validation import TGISValidationError | ||
from vllm.lora.request import LoRARequest | ||
|
||
|
||
@dataclasses.dataclass | ||
class AdapterMetadata: | ||
unique_id: int # Unique integer for vllm to identify the adapter | ||
adapter_type: str # The string name of the peft adapter type, e.g. LORA | ||
full_path: str | ||
|
||
|
||
@dataclasses.dataclass | ||
class AdapterStore: | ||
cache_path: str # Path to local store of adapters to load from | ||
adapters: Dict[str, AdapterMetadata] | ||
next_unique_id: int = 1 | ||
|
||
|
||
def validate_adapters( | ||
request: Union[SingleGenerationRequest, BatchedGenerationRequest], | ||
adapter_store: Optional[AdapterStore]) -> Dict[str, LoRARequest]: | ||
"""Takes the adapter name from the request and constructs a valid | ||
engine request if one is set. Raises if the requested adapter | ||
does not exist or adapter type is unsupported | ||
|
||
Returns the kwarg dictionary to add to an engine.generate() call. | ||
""" | ||
adapter_id = request.adapter_id | ||
|
||
if adapter_id and not adapter_store: | ||
TGISValidationError.AdaptersDisabled.error() | ||
|
||
if not adapter_id or not adapter_store: | ||
return {} | ||
|
||
# If not already cached, we need to validate that files exist and | ||
# grab the type out of the adapter_config.json file | ||
if adapter_id not in adapter_store.adapters: | ||
local_adapter_path = os.path.join(adapter_store.cache_path, adapter_id) | ||
|
||
if not os.path.exists(local_adapter_path): | ||
TGISValidationError.AdapterNotFound.error( | ||
adapter_id, "directory does not exist") | ||
|
||
adapter_config_path = os.path.join(local_adapter_path, | ||
"adapter_config.json") | ||
if not os.path.exists(adapter_config_path): | ||
TGISValidationError.AdapterNotFound.error( | ||
adapter_id, "invalid adapter: no adapter_config.json found") | ||
|
||
# NB: blocks event loop | ||
Review thread on the line above:

Reviewer: I think this will be important to address - to remove all of the file access from the event loop.

Author: Yeah, I looked into this a bit and it sounds like the asyncio file access in third-party libs is... not very good. I'm not 100% up to speed on event loops - would we want to make a new executor for this, sorta like [snippet not captured], or would that just also block the loop?

Reviewer: Yeah, exactly.. probably should just make that function be all the code that's run if we don't find the adapter in the dict (i.e. checking on disk, loading it, etc.). There's a default asyncio executor that can be used for this kind of thing, or we may want a static one rather than creating one on the fly (not that you were necessarily suggesting that).

Author: Cool, I'll see if I can get that working quickly.

Author: @njhill can I get a run from your static analysis on this change?
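For reference, a minimal sketch of the default-executor idea discussed above (hypothetical helper names; not necessarily what the PR eventually landed):

```python
import asyncio
import json


def _read_adapter_config(path: str) -> dict:
    # Plain blocking read, intended to run in a worker thread
    with open(path) as f:
        return json.load(f)


async def _load_adapter_config(path: str) -> dict:
    # run_in_executor(None, ...) uses the loop's default ThreadPoolExecutor,
    # so the blocking open()/json.load() won't stall other coroutines
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _read_adapter_config, path)
```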
        with open(adapter_config_path) as adapter_config_file:
            adapter_config = json.load(adapter_config_file)

        adapter_type = adapter_config.get("peft_type", None)

        # Add to cache
        adapter_store.adapters[adapter_id] = AdapterMetadata(
            unique_id=adapter_store.next_unique_id,
            adapter_type=adapter_type,
            full_path=local_adapter_path)
        # Bump the counter so the next adapter registered gets a
        # distinct lora_int_id (otherwise every adapter would share id 1)
        adapter_store.next_unique_id += 1

    # Build the proper vllm request object
    adapter_metadata = adapter_store.adapters[adapter_id]
    if adapter_metadata.adapter_type == "LORA":
        lora_request = LoRARequest(lora_name=adapter_id,
                                   lora_int_id=adapter_metadata.unique_id,
                                   lora_local_path=adapter_metadata.full_path)
        return {"lora_request": lora_request}

    # All other types unsupported
    TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type)
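To illustrate the docstring's contract, here is roughly how the returned kwargs plug into engine.generate() (a hedged sketch; `engine`, `request`, and the paths are hypothetical stand-ins):

```python
from vllm.entrypoints.grpc.adapters import AdapterStore, validate_adapters

# Suppose request.adapter_id == "my-lora" and the files live under
# /adapters/my-lora/ with a valid adapter_config.json
store = AdapterStore(cache_path="/adapters", adapters={})

# Returns {} when no adapter is requested, or
# {"lora_request": LoRARequest(...)} for a LoRA adapter
adapter_kwargs = validate_adapters(request=request, adapter_store=store)

results = engine.generate(prompt=text,
                          sampling_params=sampling_params,
                          request_id=request_id,
                          **adapter_kwargs)
```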
The second file changed is the gRPC server module:

@@ -14,6 +14,7 @@
 from vllm import (AsyncLLMEngine, CompletionOutput, RequestOutput,
                   SamplingParams)
 from vllm.config import ModelConfig
+from vllm.entrypoints.grpc.adapters import AdapterStore, validate_adapters
 from vllm.entrypoints.grpc.pb import generation_pb2_grpc  # type: ignore
 # yapf: disable
 from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,

@@ -32,6 +33,7 @@
 from vllm.entrypoints.grpc.validation import validate_input, validate_params
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.sequence import Logprob
 from vllm.tgis_utils import logs
 from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
@@ -116,9 +118,17 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
         self.skip_special_tokens = not args.output_special_tokens
         self.default_include_stop_seqs = args.default_include_stop_seqs

+        self.adapter_store: Optional[AdapterStore] = None
+        if args.adapter_cache:
+            self.adapter_store = AdapterStore(
+                cache_path=args.adapter_cache,
+                adapters={}
+            )
+
     async def _post_init(self):
         self.config = await self.engine.get_model_config()
-        self.tokenizer_group = await self.engine.get_tokenizer_group()
+        # self.tokenizer_group = await self.engine.get_tokenizer_group()
+        self.tokenizer_group = self.engine.engine.tokenizer
         self.tokenizer = await self.engine.get_tokenizer()

         # Swap in the special TGIS stats logger
@@ -144,6 +154,9 @@ async def Generate(self, request: BatchedGenerationRequest,

         generators = []
+        max_is_token_limit = [False] * request_count
+
+        adapter_kwargs = await self._validate_adapters(request, context)

         for i, req in enumerate(request.requests):
             input_ids, max_is_token_limit[i]\
                 = await self._validate_prompt_and_tokenize(
@@ -154,7 +167,8 @@ async def Generate(self, request: BatchedGenerationRequest,
                 self.engine.generate(prompt=req.text,
                                      sampling_params=sampling_params,
                                      request_id=f"{request_id}-{i}",
-                                     prompt_token_ids=input_ids),
+                                     prompt_token_ids=input_ids,
+                                     **adapter_kwargs),
             )

         # TODO handle cancellation
@@ -210,13 +224,16 @@ async def GenerateStream(
             sampling_params, truncate_input_tokens, request.request.text,
             context)

+        adapter_kwargs, _ = await self._validate_adapters(request, context)
Review thread on the line above:

Reviewer: Not a tuple now, right?

Author: Oh yeah, totally not. Interestingly, Python seems totally fine with the unpacking mismatch if you leave an underscore - TIL.
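As a side note on that comment: unpacking a dict iterates over its keys, so `adapter_kwargs, _ = ...` only runs without error when the dict happens to contain exactly two keys; the underscore doesn't make Python ignore the mismatch. A quick illustration (hypothetical snippet, not from the PR):

```python
d = {"lora_request": "req", "extra": 1}
a, _ = d   # unpacks the two keys: a == "lora_request"

d = {"lora_request": "req"}
a, _ = d   # ValueError: not enough values to unpack (expected 2, got 1)
```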
+
         result_generator = self.engine.generate(
             # prompt is supplied for observability, the text is not
             # re-tokenized when `prompt_token_ids` is supplied
             prompt=request.request.text,
             sampling_params=sampling_params,
             request_id=request_id,
             prompt_token_ids=input_ids,
+            **adapter_kwargs
         )

         resp_options = request.params.response
@@ -423,6 +440,19 @@ async def _validate_and_convert_params(

         return sampling_params, deadline

+    async def _validate_adapters(self,
+                                 request: Union[SingleGenerationRequest,
+                                                BatchedGenerationRequest],
+                                 context: ServicerContext) \
+            -> Dict[str, LoRARequest]:
+        try:
+            adapters = validate_adapters(
+                request=request, adapter_store=self.adapter_store)
+        except ValueError as e:
+            service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
+            await context.abort(StatusCode.INVALID_ARGUMENT, str(e))
+        return adapters
+
     @staticmethod
     def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
                         time_limit_reached: bool
Review thread:

Reviewer: I think we should sanitize the adapter_id here to make sure that the user can't send funny things like ../../../etc/passwd.

Author: Done!
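A sketch of the kind of guard being requested (hypothetical - the PR's actual fix isn't shown in this diff view). It rejects any adapter_id whose resolved path escapes the cache directory:

```python
import os


def _resolve_adapter_path(adapter_id: str, cache_path: str) -> str:
    """Resolve adapter_id under cache_path, rejecting path traversal."""
    base = os.path.abspath(cache_path)
    full_path = os.path.abspath(os.path.join(base, adapter_id))
    # Ids like "../../../etc/passwd" resolve outside the cache dir
    if os.path.commonpath([base, full_path]) != base:
        raise ValueError(f"Invalid adapter id: {adapter_id}")
    return full_path
```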