Commit 3ee696a

[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518)

Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
youngkent authored Feb 12, 2025
1 parent 72c2b68 commit 3ee696a
Showing 11 changed files with 343 additions and 41 deletions.
5 changes: 3 additions & 2 deletions benchmarks/benchmark_serving.py
@@ -1275,11 +1275,12 @@ def main(args: argparse.Namespace):
         '--tokenizer-mode',
         type=str,
         default="auto",
-        choices=['auto', 'slow', 'mistral'],
+        choices=['auto', 'slow', 'mistral', 'custom'],
         help='The tokenizer mode.\n\n* "auto" will use the '
         'fast tokenizer if available.\n* "slow" will '
         'always use the slow tokenizer. \n* '
-        '"mistral" will always use the `mistral_common` tokenizer.')
+        '"mistral" will always use the `mistral_common` tokenizer. \n*'
+        '"custom" will use --tokenizer to select the preregistered tokenizer.')
 
     parser.add_argument("--served-model-name",
                         type=str,
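With the new mode in place, a serving-benchmark run against a preregistered tokenizer might look like the following sketch (the model and tokenizer names are hypothetical, and the benchmark's other required arguments are omitted):

    python benchmarks/benchmark_serving.py \
        --model my-org/my-model \
        --tokenizer my_tokenizer \
        --tokenizer-mode custom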
123 changes: 123 additions & 0 deletions tests/tokenization/test_tokenizer_registry.py
@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
                                                    TokenizerRegistry)

if TYPE_CHECKING:
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam


class TestTokenizer(TokenizerBase):

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
        return TestTokenizer()

    @property
    def all_special_tokens_extended(self) -> List[str]:
        raise NotImplementedError()

    @property
    def all_special_tokens(self) -> List[str]:
        raise NotImplementedError()

    @property
    def all_special_ids(self) -> List[int]:
        raise NotImplementedError()

    @property
    def bos_token_id(self) -> int:
        return 0

    @property
    def eos_token_id(self) -> int:
        return 1

    @property
    def sep_token(self) -> str:
        raise NotImplementedError()

    @property
    def pad_token(self) -> str:
        raise NotImplementedError()

    @property
    def is_fast(self) -> bool:
        raise NotImplementedError()

    @property
    def vocab_size(self) -> int:
        raise NotImplementedError()

    @property
    def max_token_id(self) -> int:
        raise NotImplementedError()

    def __call__(
        self,
        text: Union[str, List[str], List[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        raise NotImplementedError()

    def get_vocab(self) -> Dict[str, int]:
        raise NotImplementedError()

    def get_added_vocab(self) -> Dict[str, int]:
        raise NotImplementedError()

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        raise NotImplementedError()

    def encode(self,
               text: str,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        raise NotImplementedError()

    def apply_chat_template(self,
                            messages: List["ChatCompletionMessageParam"],
                            tools: Optional[List[Dict[str, Any]]] = None,
                            **kwargs) -> List[int]:
        raise NotImplementedError()

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        raise NotImplementedError()

    def decode(self,
               ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        raise NotImplementedError()

    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
        raise NotImplementedError()


def test_customized_tokenizer():
    TokenizerRegistry.register("test_tokenizer",
                               "tests.tokenization.test_tokenizer_registry",
                               "TestTokenizer")

    tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
    assert isinstance(tokenizer, TestTokenizer)
    assert tokenizer.bos_token_id == 0
    assert tokenizer.eos_token_id == 1

    tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
    assert isinstance(tokenizer, TestTokenizer)
    assert tokenizer.bos_token_id == 0
    assert tokenizer.eos_token_id == 1
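Outside the test suite, the same pattern is how user code would plug a custom tokenizer into vLLM. A minimal sketch, assuming a package my_pkg.tokenizers that exposes a TokenizerBase subclass named MyTokenizer (both names hypothetical):

    from vllm.transformers_utils.tokenizer import get_tokenizer
    from vllm.transformers_utils.tokenizer_base import TokenizerRegistry

    # Register by import path; the class is only resolved when the
    # tokenizer is first requested.
    TokenizerRegistry.register("my_tokenizer",
                               "my_pkg.tokenizers",
                               "MyTokenizer")

    # tokenizer_mode="custom" routes the lookup through the registry
    # instead of HuggingFace's AutoTokenizer.
    tokenizer = get_tokenizer("my_tokenizer", tokenizer_mode="custom")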
9 changes: 5 additions & 4 deletions vllm/config.py
@@ -102,8 +102,9 @@ class ModelConfig:
     it; otherwise, you must specify explicitly which task to use.
     tokenizer: Name or path of the huggingface tokenizer to use.
     tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
-        available, "slow" will always use the slow tokenizer, and
-        "mistral" will always use the tokenizer from `mistral_common`.
+        available, "slow" will always use the slow tokenizer,
+        "mistral" will always use the tokenizer from `mistral_common`, and
+        "custom" will use --tokenizer to select the preregistered tokenizer.
     trust_remote_code: Trust remote code (e.g., from HuggingFace) when
         downloading the model and tokenizer.
     allowed_local_media_path: Allowing API requests to read local images or
@@ -467,10 +468,10 @@ def _init_has_inner_state(self) -> bool:

     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow", "mistral"]:
+        if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
             raise ValueError(
                 f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto', 'slow' or 'mistral'.")
+                "either 'auto', 'slow', 'mistral' or 'custom'.")
         self.tokenizer_mode = tokenizer_mode
 
     def _get_preferred_task(
6 changes: 4 additions & 2 deletions vllm/engine/arg_utils.py
@@ -284,11 +284,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             '--tokenizer-mode',
             type=str,
             default=EngineArgs.tokenizer_mode,
-            choices=['auto', 'slow', 'mistral'],
+            choices=['auto', 'slow', 'mistral', 'custom'],
             help='The tokenizer mode.\n\n* "auto" will use the '
             'fast tokenizer if available.\n* "slow" will '
             'always use the slow tokenizer. \n* '
-            '"mistral" will always use the `mistral_common` tokenizer.')
+            '"mistral" will always use the `mistral_common` tokenizer. \n* '
+            '"custom" will use --tokenizer to select the '
+            'preregistered tokenizer.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')
31 changes: 19 additions & 12 deletions vllm/entrypoints/llm.py
@@ -1051,9 +1051,9 @@ def _embedding_score(

     def _cross_encoding_score(
         self,
-        tokenizer: Union[AnyTokenizer],
-        text_1: List[Union[str, TextPrompt, TokensPrompt]],
-        text_2: List[Union[str, TextPrompt, TokensPrompt]],
+        tokenizer: AnyTokenizer,
+        text_1: List[str],
+        text_2: List[str],
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
@@ -1176,29 +1176,36 @@ def ensure_str(prompt: SingletonPrompt):
         if isinstance(text_1, (str, dict)):
             # Convert a single prompt to a list.
             text_1 = [text_1]
-        text_1 = [ensure_str(t) for t in text_1]
+        input_text_1: List[str] = [ensure_str(t) for t in text_1]
 
         if isinstance(text_2, (str, dict)):
             # Convert a single prompt to a list.
             text_2 = [text_2]
-        text_2 = [ensure_str(t) for t in text_2]
+        input_text_2: List[str] = [ensure_str(t) for t in text_2]
 
-        if len(text_1) > 1 and len(text_1) != len(text_2):
+        if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
             raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
-        if len(text_1) == 0:
+        if len(input_text_1) == 0:
             raise ValueError("At least one text element must be given")
-        if len(text_2) == 0:
+        if len(input_text_2) == 0:
             raise ValueError("At least one text_pair element must be given")
 
         if self.llm_engine.model_config.is_cross_encoder:
-            return self._cross_encoding_score(tokenizer, text_1, text_2,
+            return self._cross_encoding_score(tokenizer, input_text_1,
+                                              input_text_2,
                                               truncate_prompt_tokens, use_tqdm,
                                               lora_request,
                                               prompt_adapter_request)
-        else:
-            return self._embedding_score(tokenizer, text_1, text_2,
-                                         truncate_prompt_tokens, use_tqdm,
-                                         lora_request, prompt_adapter_request)
+
+        return self._embedding_score(
+            tokenizer,
+            input_text_1,  # type: ignore[arg-type]
+            input_text_2,  # type: ignore[arg-type]
+            truncate_prompt_tokens,
+            use_tqdm,
+            lora_request,
+            prompt_adapter_request)
 
     def start_profile(self) -> None:
         self.llm_engine.start_profile()
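The length checks above define how score() broadcasts its two input lists: one query against N documents, or element-wise N:N pairs. A usage sketch, given an LLM instance llm and hypothetical query/document strings:

    # 1:N, one query scored against three documents
    llm.score("What is vLLM?", [doc_a, doc_b, doc_c])

    # N:N, element-wise pairs
    llm.score([q1, q2], [d1, d2])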
3 changes: 1 addition & 2 deletions vllm/entrypoints/openai/serving_engine.py
@@ -400,8 +400,7 @@ async def _preprocess_chat(
         _chat_template_kwargs.update(chat_template_kwargs or {})
 
         request_prompt: Union[str, List[int]]
-        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
-        if is_mistral_tokenizer:
+        if isinstance(tokenizer, MistralTokenizer):
             request_prompt = apply_mistral_chat_template(
                 tokenizer,
                 messages=messages,
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_score.py
@@ -121,7 +121,7 @@ async def create_score(

            tokenize_async = make_async(tokenizer.__call__,
                                        executor=self._tokenizer_executor)
-            prompt_inputs = await tokenize_async(text=q,
+            prompt_inputs = await tokenize_async(q,
                                                  text_pair=t,
                                                  **tokenization_kwargs)
 
2 changes: 1 addition & 1 deletion vllm/logits_process.py
@@ -31,7 +31,7 @@ def get_bad_words_logits_processors(

         if isinstance(tokenizer, MistralTokenizer):
             # Mistral tokenizers should not add special tokens
-            prompt_token_ids = tokenizer.encode(prompt=prompt)
+            prompt_token_ids = tokenizer.encode(text=prompt)
         else:
             prompt_token_ids = tokenizer.encode(text=prompt,
                                                 add_special_tokens=False)
18 changes: 12 additions & 6 deletions vllm/transformers_utils/tokenizer.py
@@ -14,14 +14,16 @@
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
+                                                    TokenizerRegistry)
 from vllm.transformers_utils.tokenizers import MistralTokenizer
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import make_async
 
 logger = init_logger(__name__)
 
 AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
-                     MistralTokenizer]
+                     TokenizerBase]
 
 
 def decode_tokens(
@@ -47,11 +49,7 @@ def encode_tokens(
     Backend-agnostic equivalent of HF's
     :code:`tokenizer.encode(text, add_special_tokens=...)`.
     """
-    if isinstance(tokenizer, MistralTokenizer):
-        return tokenizer.tokenizer.encode(text,
-                                          bos=add_special_tokens,
-                                          eos=add_special_tokens)
-    elif add_special_tokens is not None:
+    if add_special_tokens is not None:
         return tokenizer.encode(text, add_special_tokens=add_special_tokens)
     return tokenizer.encode(text)
 
@@ -183,9 +181,17 @@ def get_tokenizer(
             'encoding and decoding.',
             FutureWarning,
             stacklevel=2)
+
+    tokenizer: AnyTokenizer
     if tokenizer_mode == "mistral":
         tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                      revision=revision)
+    elif tokenizer_mode == "custom":
+        tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
+                                                    *args,
+                                                    revision=revision,
+                                                    download_dir=download_dir,
+                                                    **kwargs)
     else:
         try:
             tokenizer = AutoTokenizer.from_pretrained(
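With the Mistral special case gone from encode_tokens and the custom branch added here, every backend is reached through the same helpers. A usage sketch (the model name is hypothetical):

    from vllm.transformers_utils.tokenizer import encode_tokens, get_tokenizer

    tok = get_tokenizer("my-org/my-model")  # "auto" mode by default
    ids = encode_tokens(tok, "Hello, world!", add_special_tokens=False)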
[Diffs for the remaining changed files, including the new vllm/transformers_utils/tokenizer_base.py, did not load.]
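Although tokenizer_base.py did not render, the call sites above pin down the registry's contract: register(name, module, class_name) records an import path, and get_tokenizer(name, *args, **kwargs) resolves the class lazily and calls its from_pretrained. A reconstruction under those assumptions, not the committed code:

    import importlib
    from typing import Dict, Tuple

    class TokenizerRegistry:
        # Tokenizer name -> (module path, class name); the class is
        # imported only when the tokenizer is first requested.
        REGISTRY: Dict[str, Tuple[str, str]] = {}

        @staticmethod
        def register(name: str, module: str, class_name: str) -> None:
            TokenizerRegistry.REGISTRY[name] = (module, class_name)

        @staticmethod
        def get_tokenizer(tokenizer_name: str, *args, **kwargs):
            entry = TokenizerRegistry.REGISTRY.get(tokenizer_name)
            if entry is None:
                raise ValueError(
                    f"Tokenizer {tokenizer_name} is not registered.")
            module_name, class_name = entry
            tokenizer_cls = getattr(importlib.import_module(module_name),
                                    class_name)
            return tokenizer_cls.from_pretrained(*args, **kwargs)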
