Skip to content

Commit

Permalink
Merge a759728 into 0b4d12b
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 authored Jun 4, 2024
2 parents 0b4d12b + a759728 commit fe5cb3d
Show file tree
Hide file tree
Showing 4 changed files with 297 additions and 78 deletions.
9 changes: 8 additions & 1 deletion outlines/fsm/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from outlines.fsm.regex import (
fsm_union,
get_sub_fsms_from_seq,
get_token_transitions,
make_deterministic_fsm,
walk_fsm,
)
Expand Down Expand Up @@ -569,9 +570,15 @@ def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None)

text_part = text[start_pos:]

text_transitions = get_token_transitions(
self.fsm.fsm_info.alphabet_symbol_mapping,
self.fsm.fsm_info.alphabet_anything_value,
text_part,
)

state_seq = walk_fsm(
self.fsm,
text_part,
text_transitions,
start_state,
full_match=self.match_whole,
)
Expand Down
114 changes: 78 additions & 36 deletions outlines/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,11 @@ def fsm_info(self):
((k, z) for k, v in self.trans_key_to_states.items() for z in v),
dtype=np.dtype("int64, int64"),
)
alphabet_symbol_mapping_items = np.fromiter(
(
it
for it in self.alphabet._symbol_mapping.items()
if it[0] != anything_else
),
dtype=np.dtype("U2, int64"),
)
alphabet_symbol_mapping_items = [
(k, v)
for k, v in self.alphabet._symbol_mapping.items()
if k != anything_else
]
nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64"))
self.__dict__["_fsm_info"] = create_fsm_info(
self.initial,
Expand All @@ -110,7 +107,7 @@ def fsm_info(self):

nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_2_type = numba.types.UnicodeCharSeq(2)
nb_unicode_type = numba.types.unicode_type


@numba.njit(cache=True)
Expand All @@ -136,7 +133,7 @@ def create_fsm_info(

# use 2-char strings so that we can represent incomplete utf-8 sequences
# as 2-hex-digit pairs
alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_2_type, numba.int64)
alphabet_symbol_map = numba.typed.Dict.empty(nb_unicode_type, numba.int64)
for symbol_and_trans_key in alphabet_symbol_mapping_items:
alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]

Expand Down Expand Up @@ -199,7 +196,7 @@ def transition_trie_setdefault(


def byte_symbol(byte: int) -> str:
return f"{byte:02X}" if byte >= 0x80 else chr(byte)
return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)


def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM:
Expand Down Expand Up @@ -415,21 +412,19 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: Sequence[str],
token_trans_key_seq: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
state = start_state
accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
last_final_idx: int = numba.uint64(0)

for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the tokens symbols.
for i, trans_key in enumerate(token_trans_key_seq):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
Expand All @@ -453,7 +448,7 @@ def _walk_fsm(

def walk_fsm(
fsm: BetterFSM,
input_string: Sequence[str],
token_trans_key_seq: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
Expand All @@ -463,13 +458,11 @@ def walk_fsm(
accepted_states: List[int] = []
last_final_idx: int = 0

alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
alphabet_anything_value = fsm.alphabet.anything_value
fsm_transitions = fsm.flat_transition_map

for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the tokens symbols.
for i, trans_key in enumerate(token_trans_key_seq):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
Expand Down Expand Up @@ -655,24 +648,25 @@ def state_scan_tokens(
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
vocabulary: List[Tuple[str, Sequence[int]]],
token_trans_key_seqs: List[Sequence[int]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()

for token, token_ids in vocabulary:
for (token, token_ids), token_trans_key_seq in zip(
vocabulary, token_trans_key_seqs
):
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
alphabet_anything_value,
fsm_initial,
fsm_finals,
token,
token_trans_key_seq,
start_state,
False,
)

if state_seq is not None and len(state_seq) < len(token):
if state_seq is not None and len(state_seq) < len(token_trans_key_seq):
continue

for token_id in token_ids:
Expand All @@ -681,9 +675,51 @@ def state_scan_tokens(
return res


@numba.njit(cache=True, nogil=True)
def get_token_transitions(
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
token_str: str,
) -> Sequence[int]:
trans_key_seq = []
i = 0
while i < len(token_str):
if token_str[i] == "\x00" and i != len(token_str) - 1:
symbol = token_str[i : i + 3]
i += 3
else:
symbol = token_str[i]
i += 1

trans_key_seq.append(
alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
)

trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64)
for j in range(len(trans_key_seq)):
trans_key_seq_array[j] = trans_key_seq[j]
return trans_key_seq_array


@numba.njit(cache=True, nogil=True)
def get_tokens_trans_keys(
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
vocabulary: List[Tuple[str, Sequence[int]]],
) -> List[Sequence[int]]:
tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
for token_str, _ in vocabulary:
trans_key_seq_array = get_token_transitions(
alphabet_symbol_mapping, alphabet_anything_value, token_str
)
tokens_trans_keys.append(trans_key_seq_array)

return tokens_trans_keys


def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
vocabulary: List[Tuple[str, Sequence[int]]],
) -> Dict[int, Set[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""

Expand All @@ -699,6 +735,12 @@ def create_fsm_index_end_to_end(
desc="Compiling FSM index for all state transitions",
)

tokens_trans_key_seqs = get_tokens_trans_keys(
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
vocabulary,
)

while next_states:
start_state = next_states.pop()

Expand All @@ -709,6 +751,7 @@ def create_fsm_index_end_to_end(
fsm_info.initial,
fsm_info.finals,
vocabulary,
tokens_trans_key_seqs,
start_state,
)

Expand Down Expand Up @@ -771,7 +814,7 @@ def gpt2_unicode_to_bytes():
@lru_cache
def reduced_vocabulary(
tokenizer: "Tokenizer",
) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
empty_token_ids = set()
vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}
Expand Down Expand Up @@ -804,7 +847,7 @@ def reduced_vocabulary(
raise RuntimeError(
f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}"
)
token_str = tuple(byte_symbol(b) for b in token_bytes)
token_str = "".join(byte_symbol(b) for b in token_bytes)

vocabulary.setdefault(token_str, []).append(token_idx)
else:
Expand All @@ -813,15 +856,14 @@ def reduced_vocabulary(
vocabulary_nb = numba.typed.List.empty_list(
numba.types.Tuple(
(
nb_unichar_2_type[:],
nb_unicode_type,
numba.int64[:],
)
)
)
for token_tuple, token_ids in vocabulary.items():
token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
for token_str, token_ids in vocabulary.items():
token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
vocabulary_nb.append((token_tuple_np, token_ids_np))
vocabulary_nb.append((token_str, token_ids_np))

return vocabulary_nb, empty_token_ids

Expand Down
14 changes: 9 additions & 5 deletions tests/fsm/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
from outlines.fsm.parsing import PartialLark, PartialPythonIndenter


def test_partial_parsing():
@pytest.fixture
def cleanup_lark_import():
yield
# Clean up lark.lark.LarkOptions._defaults
importlib.reload(lark.lark)


def test_partial_parsing(cleanup_lark_import):
lp = PartialLark.open_from_package(
"tests",
"partial_python.lark",
Expand Down Expand Up @@ -136,11 +143,8 @@ def test_partial_parsing():
assert len(parser_state.state_stack) == 4
assert parser_state.value_stack[-1].type == "LPAR"

# Clean up lark.lark.LarkOptions._defaults
importlib.reload(lark.lark)


def test_sequential_parse_example():
def test_sequential_parse_example(cleanup_lark_import):
input_tokens = [
"x ",
"= ",
Expand Down
Loading

0 comments on commit fe5cb3d

Please sign in to comment.