diff --git a/outlines/fsm/parsing.py b/outlines/fsm/parsing.py index 9ebc2af55..19deb975e 100644 --- a/outlines/fsm/parsing.py +++ b/outlines/fsm/parsing.py @@ -38,6 +38,7 @@ from outlines.fsm.regex import ( fsm_union, get_sub_fsms_from_seq, + get_token_transitions, make_deterministic_fsm, walk_fsm, ) @@ -569,9 +570,15 @@ def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None) text_part = text[start_pos:] + text_transitions = get_token_transitions( + self.fsm.fsm_info.alphabet_symbol_mapping, + self.fsm.fsm_info.alphabet_anything_value, + text_part, + ) + state_seq = walk_fsm( self.fsm, - text_part, + text_transitions, start_state, full_match=self.match_whole, ) diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py index b68e31897..6e2b81412 100644 --- a/outlines/fsm/regex.py +++ b/outlines/fsm/regex.py @@ -87,14 +87,11 @@ def fsm_info(self): ((k, z) for k, v in self.trans_key_to_states.items() for z in v), dtype=np.dtype("int64, int64"), ) - alphabet_symbol_mapping_items = np.fromiter( - ( - it - for it in self.alphabet._symbol_mapping.items() - if it[0] != anything_else - ), - dtype=np.dtype("U2, int64"), - ) + alphabet_symbol_mapping_items = [ + (k, v) + for k, v in self.alphabet._symbol_mapping.items() + if k != anything_else + ] nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64")) self.__dict__["_fsm_info"] = create_fsm_info( self.initial, @@ -110,7 +107,7 @@ def fsm_info(self): nb_int_list_type = numba.types.ListType(numba.int64) nb_int_pair_type = numba.types.UniTuple(numba.int64, 2) -nb_unichar_2_type = numba.types.UnicodeCharSeq(2) +nb_unicode_type = numba.types.unicode_type @numba.njit(cache=True) @@ -136,7 +133,7 @@ def create_fsm_info( # use 2-char strings so that we can represent incomplete utf-8 sequences # as 2-hex-digit pairs - alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_2_type, numba.int64) + alphabet_symbol_map = numba.typed.Dict.empty(nb_unicode_type, numba.int64) for symbol_and_trans_key in alphabet_symbol_mapping_items: alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1] @@ -199,7 +196,7 @@ def transition_trie_setdefault( def byte_symbol(byte: int) -> str: - return f"{byte:02X}" if byte >= 0x80 else chr(byte) + return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte) def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM: @@ -415,11 +412,9 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: @numba.njit(nogil=True, cache=True) def _walk_fsm( fsm_transitions: Dict[Tuple[int, int], int], - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, fsm_initial: int, fsm_finals: Set[int], - input_string: Sequence[str], + token_trans_key_seq: Sequence[int], start_state: int, full_match: bool = True, ) -> List[int]: @@ -427,9 +422,9 @@ def _walk_fsm( accepted_states: List[int] = numba.typed.List.empty_list(numba.int64) last_final_idx: int = numba.uint64(0) - for i, symbol in enumerate(input_string): - trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value) - + # Iterate over token transition key sequence. The transition key + # sequence represents the FSM traversal rules of the tokens symbols. + for i, trans_key in enumerate(token_trans_key_seq): new_state = fsm_transitions.get((state, trans_key)) if new_state is None: @@ -453,7 +448,7 @@ def _walk_fsm( def walk_fsm( fsm: BetterFSM, - input_string: Sequence[str], + token_trans_key_seq: Sequence[int], start_state: int, full_match: bool = True, ) -> List[int]: @@ -463,13 +458,11 @@ def walk_fsm( accepted_states: List[int] = [] last_final_idx: int = 0 - alphabet_symbol_mapping = fsm.alphabet._symbol_mapping - alphabet_anything_value = fsm.alphabet.anything_value fsm_transitions = fsm.flat_transition_map - for i, symbol in enumerate(input_string): - trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value) - + # Iterate over token transition key sequence. The transition key + # sequence represents the FSM traversal rules of the tokens symbols. + for i, trans_key in enumerate(token_trans_key_seq): new_state = fsm_transitions.get((state, trans_key)) if new_state is None: @@ -655,24 +648,25 @@ def state_scan_tokens( alphabet_anything_value: int, fsm_initial: int, fsm_finals: Set[int], - vocabulary: List[Tuple[Sequence[str], Sequence[int]]], + vocabulary: List[Tuple[str, Sequence[int]]], + token_trans_key_seqs: List[Sequence[int]], start_state: int, ) -> Set[Tuple[int, int]]: res = set() - for token, token_ids in vocabulary: + for (token, token_ids), token_trans_key_seq in zip( + vocabulary, token_trans_key_seqs + ): state_seq = _walk_fsm( fsm_transitions, - alphabet_symbol_mapping, - alphabet_anything_value, fsm_initial, fsm_finals, - token, + token_trans_key_seq, start_state, False, ) - if state_seq is not None and len(state_seq) < len(token): + if state_seq is not None and len(state_seq) < len(token_trans_key_seq): continue for token_id in token_ids: @@ -681,9 +675,51 @@ def state_scan_tokens( return res +@numba.njit(cache=True, nogil=True) +def get_token_transitions( + alphabet_symbol_mapping: Dict[str, int], + alphabet_anything_value: int, + token_str: str, +) -> Sequence[int]: + trans_key_seq = [] + i = 0 + while i < len(token_str): + if token_str[i] == "\x00" and i != len(token_str) - 1: + symbol = token_str[i : i + 3] + i += 3 + else: + symbol = token_str[i] + i += 1 + + trans_key_seq.append( + alphabet_symbol_mapping.get(symbol, alphabet_anything_value) + ) + + trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64) + for j in range(len(trans_key_seq)): + trans_key_seq_array[j] = trans_key_seq[j] + return trans_key_seq_array + + +@numba.njit(cache=True, nogil=True) +def get_tokens_trans_keys( + alphabet_symbol_mapping: Dict[str, int], + alphabet_anything_value: int, + vocabulary: List[Tuple[str, Sequence[int]]], +) -> List[Sequence[int]]: + tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:]) + for token_str, _ in vocabulary: + trans_key_seq_array = get_token_transitions( + alphabet_symbol_mapping, alphabet_anything_value, token_str + ) + tokens_trans_keys.append(trans_key_seq_array) + + return tokens_trans_keys + + def create_fsm_index_end_to_end( fsm_info: FSMInfo, - vocabulary: List[Tuple[Sequence[str], Sequence[int]]], + vocabulary: List[Tuple[str, Sequence[int]]], ) -> Dict[int, Set[Tuple[int, int]]]: """Create an FSM state-to-vocabulary map/index through end-to-end token parsing.""" @@ -699,6 +735,12 @@ def create_fsm_index_end_to_end( desc="Compiling FSM index for all state transitions", ) + tokens_trans_key_seqs = get_tokens_trans_keys( + fsm_info.alphabet_symbol_mapping, + fsm_info.alphabet_anything_value, + vocabulary, + ) + while next_states: start_state = next_states.pop() @@ -709,6 +751,7 @@ def create_fsm_index_end_to_end( fsm_info.initial, fsm_info.finals, vocabulary, + tokens_trans_key_seqs, start_state, ) @@ -771,7 +814,7 @@ def gpt2_unicode_to_bytes(): @lru_cache def reduced_vocabulary( tokenizer: "Tokenizer", -) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]: +) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]: """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" empty_token_ids = set() vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {} @@ -804,7 +847,7 @@ def reduced_vocabulary( raise RuntimeError( f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}" ) - token_str = tuple(byte_symbol(b) for b in token_bytes) + token_str = "".join(byte_symbol(b) for b in token_bytes) vocabulary.setdefault(token_str, []).append(token_idx) else: @@ -813,15 +856,14 @@ def reduced_vocabulary( vocabulary_nb = numba.typed.List.empty_list( numba.types.Tuple( ( - nb_unichar_2_type[:], + nb_unicode_type, numba.int64[:], ) ) ) - for token_tuple, token_ids in vocabulary.items(): - token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2")) + for token_str, token_ids in vocabulary.items(): token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) - vocabulary_nb.append((token_tuple_np, token_ids_np)) + vocabulary_nb.append((token_str, token_ids_np)) return vocabulary_nb, empty_token_ids diff --git a/tests/fsm/test_parsing.py b/tests/fsm/test_parsing.py index 4e093a994..b624fddee 100644 --- a/tests/fsm/test_parsing.py +++ b/tests/fsm/test_parsing.py @@ -9,7 +9,14 @@ from outlines.fsm.parsing import PartialLark, PartialPythonIndenter -def test_partial_parsing(): +@pytest.fixture +def cleanup_lark_import(): + yield + # Clean up lark.lark.LarkOptions._defaults + importlib.reload(lark.lark) + + +def test_partial_parsing(cleanup_lark_import): lp = PartialLark.open_from_package( "tests", "partial_python.lark", @@ -136,11 +143,8 @@ def test_partial_parsing(): assert len(parser_state.state_stack) == 4 assert parser_state.value_stack[-1].type == "LPAR" - # Clean up lark.lark.LarkOptions._defaults - importlib.reload(lark.lark) - -def test_sequential_parse_example(): +def test_sequential_parse_example(cleanup_lark_import): input_tokens = [ "x ", "= ", diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index 2fc8a5384..1e14182b5 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -1,5 +1,3 @@ -from typing import Sequence - import interegular import numba import numpy as np @@ -12,9 +10,11 @@ create_fsm_index_tokenizer, fsm_union, get_sub_fsms_from_seq, + get_tokens_trans_keys, make_byte_level_better_fsm, make_byte_level_fsm, make_deterministic_fsm, + reduced_vocabulary, walk_fsm, ) from outlines.models.transformers import TransformerTokenizer @@ -25,22 +25,50 @@ def identity(s): def to_bytes(s): - return [chr(b) if b < 0x80 else f"{b:02X}" for b in s.encode("utf-8")] + return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")] + + +def merge_symbols(byte_hexs): + return "".join(["\x00" + b if len(b) == 2 else b for b in byte_hexs]) + + +def token_str_to_trans_key(fsm, input_string): + vocabulary_nb = numba.typed.List.empty_list( + numba.types.Tuple((numba.types.unicode_type, numba.int64[:])) + ) + vocabulary_nb.append((input_string, np.fromiter([1], dtype=np.dtype("int64")))) + return get_tokens_trans_keys( + fsm.fsm_info.alphabet_symbol_mapping, + fsm.fsm_info.alphabet_anything_value, + vocabulary_nb, + )[0] -def walk_fsm_numba( +def walk_fsm_from_token_str( fsm, - input_string: Sequence[str], + input_string: str, + start_state: int, + full_match: bool = True, +): + return walk_fsm( + fsm, + token_str_to_trans_key(fsm, input_string), + start_state, + full_match, + ) + + +def walk_fsm_from_token_str_numba( + fsm, + input_string: str, start_state: int, full_match: bool = True, ): return _walk_fsm( fsm.fsm_info.transitions, - fsm.fsm_info.alphabet_symbol_mapping, - fsm.fsm_info.alphabet_anything_value, fsm.fsm_info.initial, fsm.fsm_info.finals, - input_string, + token_str_to_trans_key(fsm, input_string), start_state, full_match=full_match, ) @@ -49,8 +77,8 @@ def walk_fsm_numba( @pytest.mark.parametrize( "function", [ - walk_fsm, - walk_fsm_numba, + walk_fsm_from_token_str, + walk_fsm_from_token_str_numba, ], ) def test_walk_fsm(function): @@ -99,8 +127,8 @@ def test_walk_fsm(function): @pytest.mark.parametrize( "function", [ - walk_fsm, - walk_fsm_numba, + walk_fsm_from_token_str, + walk_fsm_from_token_str_numba, ], ) @pytest.mark.parametrize( @@ -115,19 +143,37 @@ def test_walk_fsm_multi_bytes(function, transform): str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True) - res = tuple(function(regex_fsm, transform("πŸ˜‚"), regex_fsm.initial, full_match=True)) + res = tuple( + function( + regex_fsm, merge_symbols(transform("πŸ˜‚")), regex_fsm.initial, full_match=True + ) + ) assert res[-1:] == (1,) res = tuple( - function(regex_fsm, transform("πŸ˜‚πŸ˜‚"), regex_fsm.initial, full_match=False) + function( + regex_fsm, + merge_symbols(transform("πŸ˜‚πŸ˜‚")), + regex_fsm.initial, + full_match=False, + ) ) assert res[-1:] == (1,) - res = tuple(function(regex_fsm, transform("!"), regex_fsm.initial, full_match=True)) + res = tuple( + function( + regex_fsm, merge_symbols(transform("!")), regex_fsm.initial, full_match=True + ) + ) assert res == tuple() res = tuple( - function(regex_fsm, transform("πŸ˜‚πŸ˜‚"), regex_fsm.initial, full_match=True) + function( + regex_fsm, + merge_symbols(transform("πŸ˜‚πŸ˜‚")), + regex_fsm.initial, + full_match=True, + ) ) assert res == tuple() @@ -194,14 +240,14 @@ def test_get_sub_fsms_from_seq(): assert fsm.accepts("+=") assert fsm.accepts("+") - state_seq = walk_fsm(fsm, "def", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "def", fsm.initial) state_seq.insert(0, fsm.fsm_info.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(0, False, True), (2, True, True)] # Make sure the old-to-new state map is correct - def_state_seq = walk_fsm(def_fsm, "def", fsm.initial) + def_state_seq = walk_fsm_from_token_str(def_fsm, "def", fsm.initial) def_state_seq.insert(0, fsm.fsm_info.initial) def_old_to_new_states = fsms_to_trans_finals[0][2] @@ -210,13 +256,13 @@ def test_get_sub_fsms_from_seq(): for old_state, new_state in zip(def_state_seq, state_seq) ) - state_seq = walk_fsm(fsm, "ef", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "ef", fsm.initial) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(2, True, True)] - name_state_seq = walk_fsm(name_fsm, "ef", fsm.initial) + name_state_seq = walk_fsm_from_token_str(name_fsm, "ef", fsm.initial) name_state_seq.insert(0, fsm.initial) name_old_to_new_states = fsms_to_trans_finals[2][2] @@ -225,13 +271,13 @@ def test_get_sub_fsms_from_seq(): for old_state, new_state in zip(name_state_seq, state_seq) ) - state_seq = walk_fsm(fsm, "match", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "match", fsm.initial) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(1, False, True), (2, True, True)] - match_state_seq = walk_fsm(match_fsm, "match", fsm.initial) + match_state_seq = walk_fsm_from_token_str(match_fsm, "match", fsm.initial) match_state_seq.insert(0, fsm.initial) match_old_to_new_states = fsms_to_trans_finals[1][2] @@ -240,25 +286,25 @@ def test_get_sub_fsms_from_seq(): for old_state, new_state in zip(match_state_seq, state_seq) ) - state_seq = walk_fsm(fsm, "defa", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "defa", fsm.initial) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(2, True, True)] - state_seq = walk_fsm(fsm, "de", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "de", fsm.initial) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(0, True, False), (2, True, True)] - state_seq = walk_fsm(fsm, "+", fsm.initial, False) + state_seq = walk_fsm_from_token_str(fsm, "+", fsm.initial, False) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(3, True, False), (4, False, True)] - state_seq = walk_fsm(fsm, "+=", fsm.initial) + state_seq = walk_fsm_from_token_str(fsm, "+=", fsm.initial) state_seq.insert(0, fsm.initial) res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) @@ -304,15 +350,15 @@ def test_create_fsm_index_end_to_end(): vocabulary_nb = numba.typed.List.empty_list( numba.types.Tuple( ( - numba.types.UnicodeCharSeq(2)[:], + numba.types.unicode_type, numba.int64[:], ) ) ) for token_tuple, token_ids in vocabulary.items(): - token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2")) + token = merge_symbols(token_tuple) token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) - vocabulary_nb.append((token_tuple_np, token_ids_np)) + vocabulary_nb.append((token, token_ids_np)) res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb) @@ -331,23 +377,25 @@ def test_create_fsm_index_end_to_end_multi_byte(): "😈a": numba.typed.List([1]), "πŸ˜‡": numba.typed.List([2]), "😍": numba.typed.List([3]), - ("F0", "9F", "98", "8D"): numba.typed.List([4]), # '😍' + merge_symbols(("F0", "9F", "98", "8D")): numba.typed.List([4]), # '😍' " 😍": numba.typed.List([5]), - (" ", "F0", "9F", "98", "8D"): numba.typed.List([6]), # ' 😍' - (" ", "F0", "9F", "98"): numba.typed.List([7]), # ' 😍' incomplete + merge_symbols((" ", "F0", "9F", "98", "8D")): numba.typed.List([6]), # ' 😍' + merge_symbols((" ", "F0", "9F", "98")): numba.typed.List( + [7] + ), # ' 😍' incomplete "": numba.typed.List([8]), } vocabulary_nb = numba.typed.List.empty_list( numba.types.Tuple( ( - numba.types.UnicodeCharSeq(2)[:], + numba.types.unicode_type, numba.int64[:], ) ) ) for token_tuple, token_ids in vocabulary.items(): - token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2")) + token_tuple_np = merge_symbols(token_tuple) token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) vocabulary_nb.append((token_tuple_np, token_ids_np)) @@ -356,7 +404,16 @@ def test_create_fsm_index_end_to_end_multi_byte(): assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}} -def test_create_fsm_index_tokenizer(): +@pytest.mark.parametrize( + "hf_tokenizer_uri", + [ + "gpt2", + "microsoft/phi-2", + "Qwen/Qwen1.5-0.5B-Chat", + "NousResearch/Hermes-2-Pro-Llama-3-8B", + ], +) +def test_create_fsm_index_tokenizer(hf_tokenizer_uri): # The combined regular expressions of a lexer state in a Python grammar regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" @@ -371,7 +428,7 @@ def test_create_fsm_index_tokenizer(): num_bytes_fsm_states = len(bytes_fsm.states) assert num_bytes_fsm_states == 235 - tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri) tokenizer = TransformerTokenizer(tokenizer) states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer( @@ -521,3 +578,112 @@ def build_regex(): ) profiler.dump_stats("line-profiler-build-json-regex.pkl") profiler.print_stats(output_unit=1e-3, summarize=True, stripzeros=True) + + +def test_token_trans_keys_identical(): + """assert two tokens w/ identical behavior wrt FSM have same trans key seq""" + + class MockTokenizer: + vocabulary = {"a": 1, "b": 2, "z": 3, "eos": 4} + special_tokens = {"eos"} + eos_token_id = 4 + + def convert_token_to_string(self, token): + return token + + tokenizer = MockTokenizer() + + pattern = r"z[ab]z" + regex_pattern = interegular.parse_pattern(pattern) + interegular_fsm = regex_pattern.to_fsm().reduce() + regex_fsm, _ = make_deterministic_fsm(interegular_fsm) + vocabulary, _ = reduced_vocabulary(tokenizer) + token_trans_keys = get_tokens_trans_keys( + regex_fsm.fsm_info.alphabet_symbol_mapping, + regex_fsm.fsm_info.alphabet_anything_value, + vocabulary, + ) + + token_str_trans_key_seq = { + token_str: trans_key_seq + for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys) + } + # `a` and `b` both are workable, but `z` has distinct transition rules + assert interegular_fsm.accepts("zaz") + assert interegular_fsm.accepts("zbz") + assert (token_str_trans_key_seq["a"] == token_str_trans_key_seq["b"]).all() + assert not (token_str_trans_key_seq["a"] == token_str_trans_key_seq["z"]).all() + + +def test_token_trans_keys_walk_fsm(): + """assert _walk_fsm works using transition keys""" + + class MockTokenizer: + vocabulary = {"ab": 1, "ac": 2, "az": 3, "eos": 4} + special_tokens = {"eos"} + eos_token_id = 4 + + def convert_token_to_string(self, token): + return token + + tokenizer = MockTokenizer() + + pattern = r"a[bc]z" + regex_pattern = interegular.parse_pattern(pattern) + interegular_fsm = regex_pattern.to_fsm().reduce() + regex_fsm, _ = make_deterministic_fsm(interegular_fsm) + vocabulary, _ = reduced_vocabulary(tokenizer) + token_trans_keys = get_tokens_trans_keys( + regex_fsm.fsm_info.alphabet_symbol_mapping, + regex_fsm.fsm_info.alphabet_anything_value, + vocabulary, + ) + + token_str_trans_key_seq = { + token_str: trans_key_seq + for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys) + } + + # verify initial state valid only for "ab" and "ac" using transition key seq + token_acceptance = {"ab": True, "ac": True, "az": False} + for token, should_accept in token_acceptance.items(): + token_trans_key_seq = token_str_trans_key_seq[token] + state_seq = _walk_fsm( + regex_fsm.fsm_info.transitions, + regex_fsm.fsm_info.initial, + regex_fsm.fsm_info.finals, + token_trans_key_seq, + regex_fsm.fsm_info.initial, + False, + ) + is_accepted = len(state_seq) >= len(token_trans_key_seq) + assert should_accept == is_accepted + + +def test_numba_leading_null_byte_UnicodeCharSeq_remains_broken(): + """Assert numba UnicodeCharSeq w/ leading \x00 is still broken""" + # EXPLANATION: + # https://github.com/outlines-dev/outlines/pull/930#issuecomment-2143535968 + + # from https://github.com/numba/numba/issues/9542 + d = numba.typed.typeddict.Dict.empty(numba.types.UnicodeCharSeq(1), numba.int64) + d["δΈ€"] = 10 # \xe4\xb8\x80 + with pytest.raises(KeyError): + str(d) + + # most characters are fine, but "\x00" is converted to "" + l = np.fromiter(["\x99", "\x00"], dtype=np.dtype("U2")) + assert str(l[0]) == "\x99" # fine + assert str(l[1]) == "" # 1-byte null converted to 0-bytes + + +@pytest.mark.parametrize("input_key", ["δΈ€", "\x00"]) +def test_numba_leading_null_byte_unicode_type_sane(input_key): + """Assert numba unicode_type w/ leading \x00 is working""" + # EXPLANATION: + # https://github.com/outlines-dev/outlines/pull/930#issuecomment-2143535968 + + # from https://github.com/numba/numba/issues/9542 + d = numba.typed.typeddict.Dict.empty(numba.types.unicode_type, numba.int64) + d["δΈ€"] = 10 # \xe4\xb8\x80 + str(d) # assert successfully interprets