Skip to content

Commit

Permalink
Merge pull request #787 from duc-phamh/improve_message_trimming
Browse files Browse the repository at this point in the history
Improve message trimming
  • Loading branch information
ishaan-jaff authored Nov 11, 2023
2 parents 0cee50f + 8e13da1 commit fd6064b
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 14 deletions.
36 changes: 32 additions & 4 deletions litellm/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys, os
import traceback
from dotenv import load_dotenv
import copy

load_dotenv()
import os
Expand Down Expand Up @@ -56,14 +56,42 @@ def test_multiple_messages_no_trimming():
# test_multiple_messages_no_trimming()


def test_large_trimming():
def test_large_trimming_multiple_messages():
messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}, {"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."},{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}]
trimmed_messages = trim_messages(messages, max_tokens=20, model="random")
trimmed_messages = trim_messages(messages, max_tokens=20, model="gpt-4-0613")
print("trimmed messages")
print(trimmed_messages)
assert(get_token_count(messages=trimmed_messages, model="random")) <= 20
assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 20
# test_large_trimming()

def test_large_trimming_single_message():
messages = [{"role": "user", "content": "This is a singlelongwordthatexceedsthelimit."}]
trimmed_messages = trim_messages(messages, max_tokens=5, model="gpt-4-0613")
assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) <= 5
assert(get_token_count(messages=trimmed_messages, model="gpt-4-0613")) > 0


def test_trimming_with_system_message_within_max_tokens():
# This message is 33 tokens long
messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
trimmed_messages = trim_messages(messages, max_tokens=30, model="gpt-4-0613") # The system message should fit within the token limit
assert len(trimmed_messages) == 2
assert trimmed_messages[0]["content"] == "This is a short system message"


def test_trimming_with_system_message_exceeding_max_tokens():
# This message is 33 tokens long. The system message is 13 tokens long.
messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613")
assert len(trimmed_messages) == 1
assert '..' in trimmed_messages[0]["content"]

def test_trimming_should_not_change_original_messages():
messages = [{"role": "system", "content": "This is a short system message"}, {"role": "user", "content": "This is a medium normal message, let's say litellm is awesome."}]
messages_copy = copy.deepcopy(messages)
trimmed_messages = trim_messages(messages, max_tokens=12, model="gpt-4-0613")
assert(messages==messages_copy)

def test_get_valid_models():
old_environ = os.environ
os.environ = {'OPENAI_API_KEY': 'temp'} # mock set only openai key in environ
Expand Down
90 changes: 80 additions & 10 deletions litellm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import aiohttp
import logging
import asyncio
import copy
from tokenizers import Tokenizer
from dataclasses import (
dataclass,
Expand Down Expand Up @@ -1111,6 +1112,50 @@ def decode(model: str, tokens: List[int]):
dec = tokenizer_json["tokenizer"].decode(tokens)
return dec

def openai_token_counter(messages, model="gpt-3.5-turbo-0613"):
"""
Return the number of tokens used by a list of messages.
Borrowed from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb.
"""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model:
print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return openai_token_counter(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return openai_token_counter(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
)
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens

def token_counter(model="", text=None, messages: Optional[List] = None):
"""
Count the number of tokens in a given text using a specified model.
Expand All @@ -1131,14 +1176,17 @@ def token_counter(model="", text=None, messages: Optional[List] = None):
raise ValueError("text and messages cannot both be None")
num_tokens = 0

if model is not None:
if model is not None:
tokenizer_json = _select_tokenizer(model=model)
if tokenizer_json["type"] == "huggingface_tokenizer":
enc = tokenizer_json["tokenizer"].encode(text)
num_tokens = len(enc.ids)
elif tokenizer_json["type"] == "openai_tokenizer":
enc = tokenizer_json["tokenizer"].encode(text)
num_tokens = len(enc)
if messages is not None:
num_tokens = openai_token_counter(messages, model=model)
else:
enc = tokenizer_json["tokenizer"].encode(text)
num_tokens = len(enc)
else:
num_tokens = len(encoding.encode(text))
return num_tokens
Expand Down Expand Up @@ -4574,13 +4622,13 @@ def completion_with_fallbacks(**kwargs):

def process_system_message(system_message, max_tokens, model):
system_message_event = {"role": "system", "content": system_message}
system_message_tokens = get_token_count(system_message_event, model)
system_message_tokens = get_token_count([system_message_event], model)

if system_message_tokens > max_tokens:
print_verbose("`tokentrimmer`: Warning, system message exceeds token limit. Trimming...")
# shorten system message to fit within max_tokens
new_system_message = shorten_message_to_fit_limit(system_message_event, max_tokens, model)
system_message_tokens = get_token_count(new_system_message, model)
system_message_tokens = get_token_count([new_system_message], model)

return system_message_event, max_tokens - system_message_tokens

Expand All @@ -4590,11 +4638,15 @@ def process_messages(messages, max_tokens, model):
final_messages = []

for message in messages:
final_messages = attempt_message_addition(final_messages, message, max_tokens, model)
used_tokens = get_token_count(final_messages, model)
available_tokens = max_tokens - used_tokens
if available_tokens <= 3:
break
final_messages = attempt_message_addition(final_messages=final_messages, message=message, available_tokens=available_tokens, max_tokens=max_tokens, model=model)

return final_messages

def attempt_message_addition(final_messages, message, max_tokens, model):
def attempt_message_addition(final_messages, message, available_tokens, max_tokens, model):
temp_messages = [message] + final_messages
temp_message_tokens = get_token_count(messages=temp_messages, model=model)

Expand All @@ -4604,7 +4656,7 @@ def attempt_message_addition(final_messages, message, max_tokens, model):
# if temp_message_tokens > max_tokens, try shortening temp_messages
elif "function_call" not in message:
# fit updated_message to be within temp_message_tokens - max_tokens (aka the amount temp_message_tokens is greate than max_tokens)
updated_message = shorten_message_to_fit_limit(message, temp_message_tokens - max_tokens, model)
updated_message = shorten_message_to_fit_limit(message, available_tokens, model)
if can_add_message(updated_message, final_messages, max_tokens, model):
return [updated_message] + final_messages

Expand All @@ -4626,6 +4678,13 @@ def shorten_message_to_fit_limit(
"""
Shorten a message to fit within a token limit by removing characters from the middle.
"""

# For OpenAI models, even blank messages cost 7 token,
# and if the buffer is less than 3, the while loop will never end,
# hence the value 10.
if 'gpt' in model and tokens_needed <= 10:
return message

content = message["content"]

while True:
Expand Down Expand Up @@ -4674,6 +4733,7 @@ def trim_messages(
"""
# Initialize max_tokens
# if users pass in max tokens, trim to this amount
messages = copy.deepcopy(messages)
try:
print_verbose(f"trimming messages")
if max_tokens == None:
Expand All @@ -4690,6 +4750,7 @@ def trim_messages(
system_message = ""
for message in messages:
if message["role"] == "system":
system_message += '\n' if system_message else ''
system_message += message["content"]

current_tokens = token_counter(model=model, messages=messages)
Expand All @@ -4703,14 +4764,23 @@ def trim_messages(
print_verbose(f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}")
if system_message:
system_message_event, max_tokens = process_system_message(system_message=system_message, max_tokens=max_tokens, model=model)
messages = messages + [system_message_event]

if max_tokens == 0: # the system messages are too long
return [system_message_event]

# Since all system messages are combined and trimmed to fit the max_tokens,
# we remove all system messages from the messages list
messages = [message for message in messages if message["role"] != "system"]

final_messages = process_messages(messages=messages, max_tokens=max_tokens, model=model)

# Add system message to the beginning of the final messages
if system_message:
final_messages = [system_message_event] + final_messages

if return_response_tokens: # if user wants token count with new trimmed messages
response_tokens = max_tokens - get_token_count(final_messages, model)
return final_messages, response_tokens

return final_messages
except Exception as e: # [NON-Blocking, if error occurs just return final_messages
print_verbose(f"Got exception while token trimming{e}")
Expand Down

0 comments on commit fd6064b

Please sign in to comment.