-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update tokenizer.py to support HF (#42)
- Loading branch information
1 parent
3b0fe3f
commit 36a0812
Showing
1 changed file
with
66 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,83 @@ | ||
import os

# Lazily-created, module-level caches so repeated calls for the same model
# don't re-instantiate clients/encoders/tokenizers.
_voyage_clients = {}
_tiktoken_encoders = {}
_hf_tokenizers = {}


def count_tokens(content: str, model: str = "gpt-3.5-turbo") -> int:
    """
    Count tokens in `content` based on the model name.

    1) If `model.startswith("voyage")`, use VoyageAI client.
    2) Otherwise, try to use `tiktoken`. If tiktoken raises KeyError
       (unrecognized model name), fallback to Hugging Face tokenizer.

    Raises:
        ImportError: if the package required for `model` (voyageai,
            tiktoken, or transformers) is not installed.
    """
    if model.startswith("voyage"):
        if model not in _voyage_clients:
            # Lazy-import VoyageAI & create a client
            try:
                import voyageai
            except ImportError as e:
                raise ImportError(
                    "`voyageai` package not found, please run `pip install voyageai`"
                ) from e

            _voyage_clients[model] = voyageai.Client()

        # VoyageAI expects a list of texts
        return _voyage_clients[model].count_tokens([content])

    try:
        import tiktoken

        # If we don't already have a tiktoken encoder for this model,
        # try to fetch it. This will raise KeyError if the model name
        # is unknown (e.g. a Hugging Face model).
        if model not in _tiktoken_encoders:
            # Temporarily set TIKTOKEN_CACHE_DIR if not present
            should_revert = False
            if "TIKTOKEN_CACHE_DIR" not in os.environ:
                should_revert = True
                os.environ["TIKTOKEN_CACHE_DIR"] = os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    "_static",
                    "tiktoken_cache"
                )

            try:
                _tiktoken_encoders[model] = tiktoken.encoding_for_model(model)
            finally:
                # Clean up TIKTOKEN_CACHE_DIR if we set it. Done in a
                # `finally` so the env var is not leaked when
                # `encoding_for_model` raises KeyError (the Hugging Face
                # fallback path below).
                if should_revert:
                    del os.environ["TIKTOKEN_CACHE_DIR"]

        # Now we can encode
        encoder = _tiktoken_encoders[model]
        return len(encoder.encode(content, allowed_special="all"))

    except ImportError as e:
        # tiktoken isn't installed at all
        raise ImportError(
            "`tiktoken` package not found, please run `pip install tiktoken`"
        ) from e

    except KeyError:
        # tiktoken raised KeyError, meaning it does not recognize this
        # `model` as a valid OpenAI model. Fall back to a Hugging Face
        # tokenizer, lazily importing `transformers` on first use.
        if model not in _hf_tokenizers:
            try:
                from transformers import AutoTokenizer
            except ImportError as e:
                raise ImportError(
                    "Hugging Face `transformers` not found. "
                    "Please install via `pip install transformers`."
                ) from e

            _hf_tokenizers[model] = AutoTokenizer.from_pretrained(model)

        hf_tokenizer = _hf_tokenizers[model]
        # For HF, simply return the length of the encoded IDs
        return len(hf_tokenizer.encode(content))