
Fixes for worker prompt truncation in ChatML case #3673

Merged · 6 commits · Aug 29, 2023
inference/worker/__main__.py (1 addition, 1 deletion)

@@ -34,7 +34,7 @@ def main():
        tokenizer = None
    else:
        tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
-        logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {tokenizer.vocab_size}")
+        logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {len(tokenizer)}")

    inference_http = utils.HttpClient(
        base_url=settings.inference_server_url,
inference/worker/basic_hf_server.py (1 addition, 1 deletion)

@@ -138,7 +138,7 @@ def load_models():
    hf_config = transformers.AutoConfig.from_pretrained(model_config.model_id)
    logger.warning(f"Loading model {model_config.model_id}...")
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
-    logger.warning(f"tokenizer {tokenizer.name_or_path} has vocab size {tokenizer.vocab_size}")
+    logger.warning(f"tokenizer {tokenizer.name_or_path} has vocab size {len(tokenizer)}")

    # see `decode_token` method, taken from HF text-generation-inference
    tokenizer.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})
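The logging fix matters because tokenizer.vocab_size reports only the base vocabulary, while len(tokenizer) also counts added tokens such as <decode-token>. A minimal sketch of the difference, using gpt2 as a stand-in checkpoint (not the worker's actual model):

import transformers

tok = transformers.AutoTokenizer.from_pretrained("gpt2")
print(tok.vocab_size, len(tok))  # 50257 50257
tok.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})
print(tok.vocab_size, len(tok))  # 50257 50258: vocab_size misses the added token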
inference/worker/utils.py (13 additions, 8 deletions)

@@ -95,11 +95,14 @@ def get_max_input_length(worker_config: inference.WorkerConfig, plugin_used: bool
    return max_input_length


-def get_tokens_until(tokens: list[int], target: int | list[int]) -> list[int]:
-    if isinstance(target, int):
-        return tokens[: tokens.index(target)]
-    else:
-        return next((i for i in range(len(tokens) - len(target) + 1) if tokens[i : i + len(target)] == target))
+def get_tokens_until(tokens: list[int], target: list[int]) -> list[int]:
+    if len(target) == 1:
+        return tokens[: tokens.index(target[0])]
+
+    for i in range(len(tokens) - len(target)):
+        if tokens[i : i + len(target)] == target:
+            break
+    return tokens[:i]
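As a usage sketch of the rewritten helper (the token IDs below are made up for illustration), it returns the prefix of tokens preceding the first occurrence of the target ID sequence:

tokens = [5, 9, 2, 7, 7, 3, 1]
get_tokens_until(tokens, [7, 7])  # -> [5, 9, 2]
get_tokens_until(tokens, [2])     # -> [5, 9] (single-ID fast path)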


def truncate_prompt(
@@ -118,8 +121,8 @@ def truncate_prompt(
"""
with shared_tokenizer_lock:
ids = tokenizer.encode(prompt)
# prompter_prefix_ids could be int or list of ints
prompter_prefix_ids = tokenizer.convert_tokens_to_ids(special_tokens["prompter"])
# list of int IDs
prompter_prefix_ids = tokenizer.encode(special_tokens["prompter"])

system_prompt: str | None = None
system_tokens: list[int] | None = None
@@ -134,7 +137,9 @@

    num_system_tokens = len(system_tokens) if system_tokens else 0
    # Maximum token allowed for the conversation, ex system prompt
-    max_conversation_length = max_input_length - num_system_tokens
+    # We incorporate a buffer to allow for final inference tokenization differing from ours
+    # This is a slightly hacky workaround and it would be better to find a cleaner solution
+    max_conversation_length = max_input_length - num_system_tokens - int(0.01 * max_input_length)
    ids = ids[-(max_conversation_length - 1) :]

    with shared_tokenizer_lock:
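Concretely, the new 1% buffer reserves headroom in case the inference server tokenizes the final prompt slightly differently from the worker. A worked sketch of the budget, using the max_input_length from the CodeLlama config added below and a hypothetical 100-token system prompt:

max_input_length = 8192                # from the model config below
num_system_tokens = 100                # hypothetical system-prompt length
buffer = int(0.01 * max_input_length)  # 81 tokens of headroom
max_conversation_length = max_input_length - num_system_tokens - buffer  # 8011
# truncate_prompt then keeps only the trailing (8011 - 1) conversation tokens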
oasst-shared/oasst_shared/model_configs.py (5 additions)

@@ -150,4 +150,9 @@ def compat_hash(self) -> str:
        max_input_length=3072,
        max_total_length=4096,
    ),
+    "OA_SFT_CodeLlama_13B_10": ModelConfig(
+        model_id="OpenAssistant/codellama-13b-oasst-sft-v10",
+        max_input_length=8192,
+        max_total_length=12288,
+    ),
}
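A short sketch of how the new entry would be looked up, assuming the enclosing dict is the MODEL_CONFIGS mapping exported by this module:

from oasst_shared.model_configs import MODEL_CONFIGS  # assumed export name

cfg = MODEL_CONFIGS["OA_SFT_CodeLlama_13B_10"]
assert cfg.model_id == "OpenAssistant/codellama-13b-oasst-sft-v10"
# 12288 total minus 8192 input leaves up to 4096 tokens for generation
print(cfg.max_total_length - cfg.max_input_length)  # 4096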