Refactor tokenization interface #65

Merged
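This PR renames the recognizer state API (`_consume_token_ids` → `_update_state_with_single_token_seq`, `_consume_bytes` → `_update_state_with_bytes`, `get_initial_accept_state` → `get_initial_parsing_state`) and renames `accept_state` variables to `parsing_state` throughout the tests. It also moves the shared tokenizer test mixin to `tests._test_token_seq_recognizer_many_tokenizer_common`, adds a T5 tokenizer test, and switches the Japanese generation example from `gpt2` to `JackFram/llama-68m`. Below is a minimal sketch of the renamed surface as inferred from the diffs; signatures and class layout are assumptions, not the library's exact API:

```python
# Hypothetical sketch of the renamed recognizer interface; method bodies and
# exact signatures are assumptions inferred from this PR's diffs.
class TokenSeqRecognizer:
    def get_initial_parsing_state(self):
        """Was get_initial_accept_state(): returns a fresh parsing state."""
        ...

    def _update_state_with_single_token_seq(self, token_ids, as_string=False, parsing_state=None):
        """Was _consume_token_ids(): advance the state over a token-id sequence."""
        ...

    def _update_state_with_bytes(self, byte_seq, parsing_state=None):
        """Was _consume_bytes(): advance the state over raw UTF-8 bytes."""
        ...
```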
12 changes: 8 additions & 4 deletions examples/generate_japanese.py
@@ -10,9 +10,13 @@
 def main():
 
     # Load model and tokenizer
-    tokenizer = AutoTokenizer.from_pretrained("gpt2") # JackFram/llama-68m"
+    tokenizer = AutoTokenizer.from_pretrained(
+        "JackFram/llama-68m"
+    ) # JackFram/llama-68m"
     tokenizer.pad_token = tokenizer.eos_token
-    model = AutoModelForCausalLM.from_pretrained("gpt2") # Load model to defined device
+    model = AutoModelForCausalLM.from_pretrained(
+        "JackFram/llama-68m"
+    ) # Load model to defined device
 
     # Load grammar
     with open("examples/grammars/japanese.ebnf", "r") as file:
@@ -21,8 +25,8 @@ def main():
     grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
 
     # Generate
-    prefix1 = "English: coffee, Japanese: "
-    prefix2 = "English: dog, Japanese: "
+    prefix1 = "こんにちは世界"
+    prefix2 = "こんにちは世界"
     input_ids = tokenizer(
         [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True
     )["input_ids"]
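Pieced together, the updated example flows roughly as below. This is a sketch, not the full script: the construction of `grammar` from the EBNF string falls outside the hunks above, so `build_grammar` is a hypothetical placeholder, and the `GrammarConstrainedLogitsProcessor` import path is assumed.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
# import path assumed; the diff only shows the class name
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("JackFram/llama-68m")
tokenizer.pad_token = tokenizer.eos_token  # the model ships without a pad token
model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m")

with open("examples/grammars/japanese.ebnf", "r") as file:
    grammar_str = file.read()

grammar = build_grammar(grammar_str)  # hypothetical: grammar construction is not shown in the hunks
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

prefix1 = prefix2 = "こんにちは世界"
input_ids = tokenizer(
    [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True
)["input_ids"]

# constrain decoding with the grammar processor
output = model.generate(input_ids, logits_processor=[grammar_processor], max_new_tokens=30)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```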
@@ -58,7 +58,9 @@ def test_json_parsable(self):
             )
             return
 
-        acc_state = JsontokenRecognizer._consume_token_ids(token_ids, as_string=False)
+        acc_state = JsontokenRecognizer._update_state_with_single_token_seq(
+            token_ids, as_string=False
+        )
         # the json object is complete, so the stacks should be empty
         self.assertTrue(
             acc_state.stacks == set() or acc_state.stacks == set(tuple()),
@@ -84,25 +86,15 @@ def test_balanced_parentheses(self):
                 f"unk token found in input_token_ids: {token_ids}, skipping test"
             )
             return
 
-        accept_state = recognizer._consume_token_ids(token_ids, as_string=False)
+        parsing_state = recognizer._update_state_with_single_token_seq(
+            token_ids, as_string=False
+        )
         # the json object is complete, so the stacks should be empty
         self.assertTrue(
-            accept_state.stacks == set() or accept_state.stacks == set(tuple()),
-            f"stacks: {accept_state.stacks}, not empty",
+            parsing_state.stacks == set() or parsing_state.stacks == set(tuple()),
+            f"stacks: {parsing_state.stacks}, not empty",
         )
-
-        # inbalanced_parentheses = "((((((((()))))))))))))"
-        # token_ids = self.tokenizer.encode(inbalanced_parentheses)
-        # pprint_token_ids(self.tokenizer, token_ids)
-        #
-        # # check if there is unk token
-        # stacks = recognizer._consume_token_ids(
-        #     token_ids, recognizer.grammar.stacks, as_string=False
-        # )
-        #
-        # self.assertTrue(stacks != [] and stacks != [[]], f"stacks: {stacks}, empty")
-
     @unittest.skip("Not implemented")
     def test_emoji(self):
         """
@@ -128,47 +120,6 @@ def test_emoji(self):
             )
             return
 
-        stacks = JsontokenRecognizer._consume_token_ids(
+        stacks = JsontokenRecognizer._update_state_with_single_token_seq(
             token_ids, JsontokenRecognizer.string_recognizer.stacks, as_string=False
         )
-
-        # parsed_grammar = parse_ebnf(input_text)
-        #
-        # start_rule_id = parsed_grammar.symbol_table["root"]
-        #
-        # recognizer = GrammarRecognizer(parsed_grammar.grammar_encoding, start_rule_id)
-        #
-        # self.assertTrue(recognizer._accept_string(emoji, recognizer.stacks))
-
-    # def test_beam_search_low_memory(self):
-    #     # Check that choosing 'low_memory' does not change the model output
-    #     for model_class in self.all_generative_model_classes:
-    #         if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-    #             self.skipTest("Won't fix: old model with different cache format")
-    #         if any(
-    #             model_name in model_class.__name__.lower()
-    #             for model_name in [
-    #                 "bloom",
-    #                 "ctrl",
-    #                 "gptbigcode",
-    #                 "transo_xl",
-    #                 "xlnet",
-    #                 "cpm",
-    #             ]
-    #         ):
-    #             self.skipTest("May fix in the future: need model-specific fixes")
-    #         config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2)
-    #         # batch_size=1 is ok, but batch_size>1 will cause non-identical output
-    #
-    #         config.use_cache = True
-    #         config.is_decoder = True
-    #
-    #         # test output equality of low versus high memory
-    #         model = model_class(config).to(torch_device).eval()
-    #
-    #         low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
-    #
-    #         high_output = model.generate(
-    #             input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
-    #         )
-    #         self.assertListEqual(low_output.tolist(), high_output.tolist())
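The assertions in these hunks all encode the same invariant: after feeding a complete input, the parsing state's stack set must be empty. A minimal sketch of that check under the renamed API, assuming `recognizer` and `token_ids` are set up as in the tests:

```python
parsing_state = recognizer._update_state_with_single_token_seq(token_ids, as_string=False)

# an empty stack set means the grammar derivation is complete (e.g. a closed JSON object);
# a non-empty set would mean the input is only a valid prefix
is_complete = parsing_state.stacks == set() or parsing_state.stacks == set(tuple())
assert is_complete, f"stacks: {parsing_state.stacks}, not empty"
```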
2 changes: 1 addition & 1 deletion tests/test_grammar_constrained_decoding.py
@@ -132,7 +132,7 @@ def check_parentheses(generation):
         grammar_str=grammar_str, start_rule_name="root", tokenizer=tokenizer
     )
 
-    accept_state = tokenRecognizer._consume_token_ids(
+    parsing_state = tokenRecognizer._update_state_with_single_token_seq(
         input_ids[0], as_string=False
     )
     # generations = tokenizer.batch_decode(output, skip_special_tokens=True)
@@ -34,10 +34,10 @@ def test_accept_japanese(self):
 
         head_bytes = bytes_japanese[:8]
         # partial_utf8 = PartialUTF8()
-        accept_state = recognizer._consume_bytes(head_bytes)
+        parsing_state = recognizer._update_state_with_bytes(head_bytes)
 
         # non empty stack means that the bytes were accepted
-        self.assertTrue(len(accept_state.stacks) > 0)
+        self.assertTrue(len(parsing_state.stacks) > 0)
 
     def test_accept_japanese_progressive(self):
         #######################
@@ -62,12 +62,12 @@ def test_accept_japanese_progressive(self):
         # cast into bytes
         byte_tokens = [bytes([byte]) for byte in byte_tokens]
 
-        accept_state = recognizer.get_initial_accept_state()
+        parsing_state = recognizer.get_initial_parsing_state()
 
-        # accept_state = recognizer.init_accept_state
+        # parsing_state = recognizer.init_parsing_state
         for i, byte in enumerate(byte_tokens):
-            accept_state = recognizer._consume_bytes(byte, accept_state)
-            self.assertTrue(len(accept_state.stacks) > 0)
+            parsing_state = recognizer._update_state_with_bytes(byte, parsing_state)
+            self.assertTrue(len(parsing_state.stacks) > 0)
 
     def test_accept_emoji(self):
         """
@@ -88,6 +88,6 @@ def test_accept_emoji(self):
         # 😀😄😂
 
         # partial_utf8 = PartialUTF8()
-        accept_state = recognizer._consume_bytes(bytes_emoji)
+        parsing_state = recognizer._update_state_with_bytes(bytes_emoji)
         # non empty stack means that the bytes were accepted
-        self.assertTrue(len(accept_state.stacks) > 0)
+        self.assertTrue(len(parsing_state.stacks) > 0)
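The progressive test above feeds the recognizer one byte at a time, which is what makes partial multi-byte UTF-8 sequences (Japanese, emoji) workable. A condensed sketch of the same loop, assuming a `recognizer` built from the Japanese grammar as in the test:

```python
parsing_state = recognizer.get_initial_parsing_state()
for b in "こんにちは世界".encode("utf-8"):
    # feed a single byte and thread the returned state into the next step
    parsing_state = recognizer._update_state_with_bytes(bytes([b]), parsing_state)
    # a non-empty stack set means the byte prefix so far is still accepted
    assert len(parsing_state.stacks) > 0
```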
4 changes: 1 addition & 3 deletions tests/test_string_recognizer/test_json.py
@@ -51,8 +51,6 @@ def test_minimal_json_object(self):
         """
         json = '{"foo": "bar", "baz": "bat"}'
 
-        # accept_state = AcceptState.empty_state()
-
         self.assertEqual(
             is_json_parsable(json),
             self.recognizer._accept_prefix(json),
@@ -71,7 +69,7 @@ def test_minimal_json_object(self):
     def test_systematic_examples(self):
 
         for name, json_object in json_examples.items():
-            # accept_state = AcceptState.empty_state()
+            # parsing_state = AcceptState.empty_state()
             self.assertEqual(
                 is_json_parsable(json_object),
                 self.recognizer._accept_prefix(json_object),
1 change: 0 additions & 1 deletion tests/test_string_recognizer/test_json_arr.py
@@ -31,7 +31,6 @@ def test_minimal_json_array(self):
         recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id)
 
         for json in jsons:
-            # accept_state = AcceptState.empty_state()
             self.assertEqual(
                 is_json_parsable(json),
                 recognizer._accept_prefix(json),
2 changes: 0 additions & 2 deletions tests/test_string_recognizer/test_unicode.py
@@ -24,8 +24,6 @@ def test_accept_japanese(self):
 
         recognizer = StringRecognizer(parsed_grammar.grammar_encoding, start_rule_id)
 
-        # accept_state = AcceptState.empty_state()
-
         self.assertTrue(recognizer._accept_prefix(japanese))
 
     def test_emoji(self):
@@ -2,12 +2,12 @@
 
 from transformers import BloomTokenizerFast
 
-from tests._tokenizer_common import TokenizerTesterMixin
+from tests._test_token_seq_recognizer_many_tokenizer_common import TokenizerTesterMixin
 
 import logging
 
 
-# @unittest.skip("GPTNeoXTokenizerFast is not available for testing")
+@unittest.skip("Bloom is not supported and will be removed")
 class BloomTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
 
     tokenizer_class = BloomTokenizerFast
@@ -2,11 +2,12 @@
 
 from transformers import CodeGenTokenizerFast
 
-from tests._tokenizer_common import TokenizerTesterMixin
+from tests._test_token_seq_recognizer_many_tokenizer_common import TokenizerTesterMixin
 
 import logging
 
 
+@unittest.skip("CodeGen is not supported and will be removed")
 class CodeGenTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
 
     tokenizer_class = CodeGenTokenizerFast
@@ -2,11 +2,12 @@
 
 from transformers import PreTrainedTokenizer, AutoTokenizer
 
-from tests._tokenizer_common import TokenizerTesterMixin
+from tests._test_token_seq_recognizer_many_tokenizer_common import TokenizerTesterMixin
 
 import logging
 
 
+@unittest.skip("Falcon is not supported and will be removed")
 class FalconTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
 
     tokenizer_class = AutoTokenizer
@@ -2,7 +2,7 @@
 
 from transformers import GPT2TokenizerFast
 
-from tests._tokenizer_common import TokenizerTesterMixin
+from tests._test_token_seq_recognizer_many_tokenizer_common import TokenizerTesterMixin
 
 import logging
 
@@ -2,11 +2,12 @@
 
 from transformers import GPTNeoXTokenizerFast
 
-from tests._tokenizer_common import TokenizerTesterMixin
+from tests._test_token_seq_recognizer_many_tokenizer_common import TokenizerTesterMixin
 
 import logging
 
 
+@unittest.skip("GPTNeoX is not supported and will be removed")
 class GPTNeoXTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
 
     tokenizer_class = GPTNeoXTokenizerFast
@@ -2,7 +2,7 @@
 
 from transformers import LlamaTokenizerFast
 
-from tests._tokenizer_common import TokenizerTesterMixin
+from tests._test_token_seq_recognizer_many_tokenizer_common import TokenizerTesterMixin
 
 import logging
 
46 changes: 46 additions & 0 deletions tests/test_token_seq_recognizer_many_tokenizers/test_t5.py
@@ -0,0 +1,46 @@
+import unittest
+
+from transformers import T5TokenizerFast
+
+from tests._test_token_seq_recognizer_many_tokenizer_common import TokenizerTesterMixin
+
+import logging
+
+
+# @unittest.skip("T5Tokenizer's mapping is not well defined, not working")
+class T5TokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+
+    tokenizer_class = T5TokenizerFast
+    pretrained_name = "t5-small"
+
+    def setUp(self):
+        super().setUp()
+
+
+class TestT5TokenizerUnkToken(unittest.TestCase):
+    def test_unk_token(self):
+        tokenizer = T5TokenizerFast.from_pretrained("t5-small")
+
+        unk_token_id = tokenizer.unk_token_id
+        unk_token = tokenizer.unk_token
+
+        # open curly brace is an unk token
+        curly_brace_open = "{"
+        # we take the 2nd token because the first token is the space token
+        curly_brace_open_id = tokenizer.encode(curly_brace_open)[1]
+        self.assertEqual(curly_brace_open_id, unk_token_id)
+
+        curly_brace_close = "}"
+        curly_brace_close_id = tokenizer.encode(curly_brace_close)[1]
+        self.assertEqual(curly_brace_close_id, unk_token_id)
+
+        eos_token_id = tokenizer.eos_token_id
+        # tab in t5 signifies the end of a line
+        tab = "\t"
+        tab_id = tokenizer.encode(tab)[0]
+        self.assertEqual(tab_id, eos_token_id)
+
+        # newline in t5 signifies the end of a line
+        newline = "\n"
+        newline_id = tokenizer.encode(newline)[0]
+        self.assertEqual(newline_id, eos_token_id)
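The new test pins down two T5 tokenizer quirks that matter for token-level grammar checking: `{` and `}` have no vocabulary entry and encode to the unk token, and whitespace-only inputs like `\t` and `\n` collapse to the eos id. A standalone check of the same behaviour (assumes `transformers` with sentencepiece installed):

```python
from transformers import T5TokenizerFast

tok = T5TokenizerFast.from_pretrained("t5-small")

# index 1 skips the leading space token that sentencepiece prepends
assert tok.encode("{")[1] == tok.unk_token_id
assert tok.encode("}")[1] == tok.unk_token_id

# tab and newline are stripped, leaving only the eos token
assert tok.encode("\t")[0] == tok.eos_token_id
assert tok.encode("\n")[0] == tok.eos_token_id
```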
@@ -2,11 +2,12 @@
 
 from transformers import XGLMTokenizerFast
 
-from tests._tokenizer_common import TokenizerTesterMixin
+from tests._test_token_seq_recognizer_many_tokenizer_common import TokenizerTesterMixin
 
 import logging
 
 
+@unittest.skip("Not Supported and Will be removed")
 class XGLMTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
 
     tokenizer_class = XGLMTokenizerFast
17 changes: 0 additions & 17 deletions tests/test_tokenizers/test_t5.py

This file was deleted.
