diff --git a/docs/contribute.md b/docs/contribute.md new file mode 100644 index 0000000..0f8dd76 --- /dev/null +++ b/docs/contribute.md @@ -0,0 +1,12 @@ +# Contribute + +We welcome contributions to the project. + +To contribute, please follow these steps: +1. Fork the repository. +2. Create a new branch for your changes with a descriptive name, e.g. `git checkout -b feature/add-support-for-xyz` or `git checkout -b fix/parsing-error-in-abc`. +3. Set up an environment and install the required dependencies via `pip install -r requirements.txt`. +4. Install the `pre-commit` hooks via `pre-commit install`. +5. Make your changes and add tests that cover them. +6. Commit your changes; the `pre-commit` hooks run automatically on commit, and the tests will verify your changes. +7. If all tests pass, push your changes to your fork and create a pull request. diff --git a/examples/generate_geo_query.py b/examples/generate_geo_query.py index adfd8e2..cf38296 100644 --- a/examples/generate_geo_query.py +++ b/examples/generate_geo_query.py @@ -6,9 +6,15 @@ from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor from transformers_cfg.parser import parse_ebnf + def parse_args(): parser = argparse.ArgumentParser(description="Generate geo query strings") - parser.add_argument("--model-id", type=str, default="/dlabdata1/llm_hub/Mistral-7B-v0.1", help="Model ID") + parser.add_argument( + "--model-id", + type=str, + default="/dlabdata1/llm_hub/Mistral-7B-v0.1", + help="Model ID", + ) parser.add_argument("--device", type=str, help="Device to put the model on") return parser.parse_args() @@ -18,14 +24,16 @@ def main(): model_id = args.model_id # Detect if GPU is available, otherwise use CPU - device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) + device = torch.device( + args.device or ("cuda" if torch.cuda.is_available() else "cpu") + ) print(f"Using device: {device}") # Load model and tokenizer tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token # Load model to defined device - model = AutoModelForCausalLM.from_pretrained(model_id).to(device) + model = AutoModelForCausalLM.from_pretrained(model_id).to(device) # Load grammar with open(f"examples/grammars/geo_query.ebnf", "r") as file: @@ -40,7 +48,7 @@ def main(): "which state contains most rivers ? ", "number of citizens in boulder ? ", "what are the major cities of the us ? ", - "what is the smallest city in washington ? ", + "what is the smallest city in washington ? ", "how many states border colorado and border new mexico ? 
", ] @@ -49,7 +57,7 @@ def main(): )["input_ids"].to( device ) # Move input_ids to the same device as model - + n_examples = input_ids.shape[0] max_new_tokens = 50 @@ -70,27 +78,35 @@ def main(): ) parsed_grammar = parse_ebnf(grammar_str) - string_grammar = StringRecognizer(parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"]) - + string_grammar = StringRecognizer( + parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"] + ) + # decode outputs (possibly of different lengths across decoding modes) - generations = tokenizer.batch_decode(unconstrained_output, skip_special_tokens=True) + \ - tokenizer.batch_decode(constrained_output, skip_special_tokens=True) + generations = tokenizer.batch_decode( + unconstrained_output, skip_special_tokens=True + ) + tokenizer.batch_decode(constrained_output, skip_special_tokens=True) print() for i in range(n_examples): print(f"Unconstrained: {generations[i]}") constrained_generation = generations[i + n_examples] print(f"Constrained: {generations[i + n_examples]}") - print(f'The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}') - print(f'The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}') + print( + f"The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}" + ) + print( + f"The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}" + ) print() + if __name__ == "__main__": main() ########################## -# Example output: -# +# Example output: +# # Unconstrained: how many states border colorado and border new mexico ? 1. # - How long is the drive from denver to albuquerque? The distance between Denver, Colorado (CO) & Alburqueque New Mexico(NM). Driving directions for your road trip or vacation: Get driving # Constrained: how many states border colorado and border new mexico ? 
answer(smallest_one(area_1(stateid('colorado')))) diff --git a/tests/test_parsing/test_parsing.py b/tests/test_parsing/test_parsing.py index aa35f32..6ef931d 100644 --- a/tests/test_parsing/test_parsing.py +++ b/tests/test_parsing/test_parsing.py @@ -138,7 +138,6 @@ def test__parse_literal_string(self): self.assertEqual(3, len(outbuf), f"len(outbuf): {len(outbuf)} != 3") self.assertListEqual([2, ord("你"), ord("你")], outbuf) - def test__parse_escape(self): escaped_char_src = '"\\n"' outbuf = [] diff --git a/transformers_cfg/generation/logits_process.py b/transformers_cfg/generation/logits_process.py index 212a541..2cabb3a 100644 --- a/transformers_cfg/generation/logits_process.py +++ b/transformers_cfg/generation/logits_process.py @@ -89,7 +89,7 @@ def process_logits(self, input_ids, scores): ) # logger.debug("stacks: \n" + pprint.pformat(self.batch_accept_states.stacks)) - self.batch_accept_states = self.grammar_constraint.advance_token_ids( + self.batch_accept_states = self.grammar_constraint.consume_token_ids( input_ids, self.batch_accept_states, self.parse_start_index ) logger.debug(f"input_ids: {input_ids}") diff --git a/transformers_cfg/recognizer.py b/transformers_cfg/recognizer.py index 3a65210..dfbc92e 100644 --- a/transformers_cfg/recognizer.py +++ b/transformers_cfg/recognizer.py @@ -93,7 +93,7 @@ def init_stack(self, start_rule_id: int) -> Set[Tuple[int]]: element_offset = sub_rhs_offset + 1 if self.grammar_encoding[element_offset] != END_OF_ALTERNATE_MARKER: stack.append(element_offset) - stacks.update(self.advance_stack(tuple(stack))) + stacks.update(self.expand_stack_head(tuple(stack))) sub_rhs_offset += 1 + self.grammar_encoding[sub_rhs_offset] return stacks @@ -104,7 +104,15 @@ def get_termination_accept_state(self) -> AcceptState: return AcceptState(set(), PartialUTF8()) @lru_cache(maxsize=32768) - def advance_stack(self, stack: Tuple[int]) -> Set[Tuple[int]]: + def expand_stack_head(self, stack: Tuple[int]) -> Set[Tuple[int]]: + """ + Stack is the internal state of the recognizer(Pushdown Automaton). + This method updates the stack by advancing it to the next element. + If the element is a non-terminal, it expands the stack by adding the elements of the referenced rule. + A new stack is created for each alternate of the referenced rule, so we could have multiple stacks as output. 
+ :param stack: + :return: + """ if len(stack) == 0: return {stack} @@ -137,24 +145,28 @@ def advance_stack(self, stack: Tuple[int]) -> Set[Tuple[int]]: if self.grammar_encoding[ref_element_offset] != END_OF_ALTERNATE_MARKER: new_stack.append(ref_element_offset) - new_stacks.update(self.advance_stack(tuple(new_stack))) + new_stacks.update(self.expand_stack_head(tuple(new_stack))) ref_subrule_offset += self.grammar_encoding[ref_subrule_offset] + 1 return new_stacks - def _consume_byte(self, byte: int, accept_state: AcceptState): + def _consume_byte(self, byte: int, accept_state: AcceptState) -> AcceptState: # suppose we have code point 一, ord('一') = 19968, we need to match 3 bytes # we need to match 3 bytes, so we need to call _consume_byte_partial_match 3 times - self._consume_bytes(bytes([byte]), accept_state) + return self._consume_bytes(bytes([byte]), accept_state) # @lru_cache(maxsize=32768) - def _probe_bytes( + def _try_accept_bytes( self, byte_seq: bytes, stacks: Set[Tuple[int]], partial_utf8: PartialUTF8, verbose=True, ): + """ + The difference between accept_bytes and consume_bytes is that accept_bytes returns a boolean and + consume_bytes returns a new accept state + """ if type(byte_seq) is list: byte_seq = bytes(byte_seq) code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8) @@ -162,7 +174,7 @@ def _probe_bytes( logging.debug( f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}" ) - new_stacks = self._consume_code_points(code_points, stacks) + new_stacks = self._consume_code_points_for_all_stacks(code_points, stacks) for stack in new_stacks: @@ -179,7 +191,7 @@ def _consume_bytes( byte_seq: bytes, accept_state: Optional[AcceptState] = None, verbose=True, - ): + ) -> AcceptState: if accept_state is None: accept_state = self.get_initial_accept_state() stacks = accept_state.stacks @@ -191,7 +203,7 @@ def _consume_bytes( logging.debug( f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}" ) - new_stacks = self._consume_code_points(code_points, stacks) + new_stacks = self._consume_code_points_for_all_stacks(code_points, stacks) new_new_stacks = set() for stack in new_stacks: @@ -209,8 +221,8 @@ def _consume_bytes( ########################## @lru_cache(maxsize=30000) - def _consume_code_point( - self, code_point: int, stacks: Set[Tuple[int]] + def _consume_code_point_for_all_stacks( + self, code_point: int, stacks: Tuple[Tuple[int]] ) -> Set[Tuple[int]]: """ consume a character from the stack @@ -221,11 +233,13 @@ def _consume_code_point( if code_point == 0: return new_stacks for stack in stacks: - new_stacks.update(self._consume_code_point_per_stack(code_point, stack)) + new_stacks.update( + self._consume_code_point_for_single_stack(code_point, stack) + ) return new_stacks @lru_cache(maxsize=30000) - def _consume_code_point_per_stack( + def _consume_code_point_for_single_stack( self, code_point: int, stack: Tuple[int] ) -> Set[Tuple[int]]: """ @@ -256,15 +270,22 @@ def _consume_code_point_per_stack( new_stack = list(stack[:-1]) if self.grammar_encoding[element_offset]: new_stack.append(element_offset) - return self.advance_stack(tuple(new_stack)) + # # Explicitly convert list to tuple of int to make it hashable + new_tuple_stack: Tuple[int, ...] = tuple(new_stack) + return self.expand_stack_head(new_tuple_stack) - def _consume_code_points( + def _consume_code_points_for_all_stacks( self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False ) -> Set[Tuple[int]]: + """ + code points is a list of Unicode code points. 
For example, the code points for "hello" is [104, 101, 108, 108, 111] + For unicode string "こんにちは世界", the code points are [12371, 12435, 12395, 12385, 12399, 19990, 30028] + + """ for i, code_point in enumerate(code_points): # for lru_cache to work, we need to convert the list of stacks into a tuple of stacks tuple_stacks: Tuple[Tuple[int], ...] = tuple(stacks) - stacks = self._consume_code_point(code_point, tuple_stacks) + stacks = self._consume_code_point_for_all_stacks(code_point, tuple_stacks) if len(stacks) > 0 and verbose: accepted_code_point = code_points[: i + 1] corresponding_char = chr(code_point) @@ -276,7 +297,7 @@ def _consume_code_points( def _accept_code_points( self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False ) -> bool: - stacks = self._consume_code_points(code_points, stacks, verbose) + stacks = self._consume_code_points_for_all_stacks(code_points, stacks, verbose) return len(stacks) > 0 @lru_cache(maxsize=30000) @@ -295,12 +316,6 @@ def accept_code_point_at_element( return True return False - # def _accept_code_point(self, code_point: int, stacks: List[List[int]]): - # # for lru_cache to work, we need to convert the list of stacks into a tuple of stacks - # tuple_stacks: Tuple[Tuple[int]] = tuple([tuple(stack) for stack in stacks]) - # new_stacks: List[List[int]] = self._consume_code_point(code_point, tuple_stacks) - # return len(new_stacks) > 0 - ############################# # # Partial UTF-8 recognition @@ -363,7 +378,9 @@ def partial_utf8_accept_at_element( def _consume_string(self, string: str, accept_state: AcceptState): # _bytes = bytes(string, "utf-8") code_points = [ord(char) for char in string] - stacks = self._consume_code_points(code_points, accept_state.stacks) + stacks = self._consume_code_points_for_all_stacks( + code_points, accept_state.stacks + ) return AcceptState(stacks, accept_state.partial_utf8) def _accept_prefix(self, string: str, accept_state: Optional[AcceptState] = None): @@ -427,23 +444,27 @@ def char_acceptance_at_element(self, element_offset): logging.debug(acceptance) return acceptance - def _consume_code_points_new( - self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False - ) -> Set[Tuple[int]]: - new_stacks: Set[Tuple[int]] = set() - for stack in stacks: - new_stacks.update( - self._consume_code_points_per_stack(tuple(code_points), stack, verbose) - ) - return new_stacks - - @lru_cache(maxsize=30000) - def _consume_code_points_per_stack( - self, code_points: Tuple[int], stack: Tuple[int], verbose=False - ) -> Set[Tuple[int]]: - for code_point in code_points: - stacks = self._consume_code_point(code_point, (stack,)) - return stacks + # def _consume_code_points_new( + # self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False + # ) -> Set[Tuple[int]]: + # new_stacks: Set[Tuple[int]] = set() + # for stack in stacks: + # new_stacks.update( + # self._consume_code_points_per_stack(tuple(code_points), stack, verbose) + # ) + # return new_stacks + # + # @lru_cache(maxsize=30000) + # def _consume_code_points_per_stack( + # self, code_points: Tuple[int], stack: Tuple[int], verbose=False + # ) -> Set[Tuple[int]]: + # stacks = {stack} + # + # for code_point in code_points: + # # Update the stacks variable by consuming each code point. 
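+ # # NOTE: as written, each iteration consumes code_point from the original (stack,) rather than from the updated stacks, so only the last code point is applied to the initial stack.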
+ # stacks = self._consume_code_point_for_all_stacks(code_point, (stack,)) + # + # return stacks if __name__ == "__main__": @@ -487,9 +508,9 @@ def _consume_code_points_per_stack( accept_state = AcceptState(recognizer.stacks, PartialUTF8()) for i, byte in enumerate(byte_tokens): - new_accept_state = recognizer._consume_bytes(byte, accept_state) - logging.debug(f"new partial utf8: {new_accept_state.partial_utf8}") - if len(new_accept_state.stacks) > 0: + accept_state = recognizer._consume_bytes(byte, accept_state) + logging.debug(f"new partial utf8: {accept_state.partial_utf8}") + if len(accept_state.stacks) > 0: logging.debug(f"byte {byte} is accepted") else: logging.debug(f"byte {byte} is not accepted") diff --git a/transformers_cfg/token_grammar_recognizer.py b/transformers_cfg/token_grammar_recognizer.py index 156dfd9..c23379c 100644 --- a/transformers_cfg/token_grammar_recognizer.py +++ b/transformers_cfg/token_grammar_recognizer.py @@ -68,7 +68,7 @@ def _consume_token_id( ) return accept_state - def probe_token_id(self, token_id: int, accept_state: AcceptState) -> bool: + def try_accept_token_id(self, token_id: int, accept_state: AcceptState) -> bool: stacks = accept_state.stacks if self.string_recognizer._must_stop(stacks): if token_id == self.eos_token_id: @@ -90,7 +90,7 @@ def probe_token_id(self, token_id: int, accept_state: AcceptState) -> bool: ) return len(new_acc_state.stacks) > 0 - def advance_token_ids(self, *args, **kwargs): + def consume_token_ids(self, *args, **kwargs): """Process a list of tokens according to the grammar rules.""" raise NotImplementedError @@ -132,12 +132,11 @@ def get_token_acceptance(self, accept_state, device) -> torch.Tensor: def get_token_acceptance_array_for_stack(self, stack, partial_utf8, device): # stack = list(stack) # needs to come in as a tuple for lru_cache assert isinstance(stack, tuple) - stack = list(stack) if self.byte_encoding: - accept_f = lambda x: self.string_recognizer._probe_bytes( - x, [stack], partial_utf8=partial_utf8 + accept_f = lambda x: self.string_recognizer._try_accept_bytes( + x, {stack}, partial_utf8=partial_utf8 ) token_acceptance = self.unicode_trie.get_token_acceptance( accept=accept_f, accept_eos=False, eos_token_id=self.eos_token_id @@ -173,7 +172,7 @@ def __init__(self, grammar_str, start_rule_name, tokenizer, unicode=False): # if self.last_size is not set (which would be the case when processing the first token). # In this case, do nothing. 
- def advance_token_ids(self, input_ids, batch_accept_states, parse_start_index=None): + def consume_token_ids(self, input_ids, batch_accept_states, parse_start_index=None): if self.last_size is None: prefix_to_parse = [ @@ -272,7 +271,7 @@ def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts) new_stack = list(stk[:-1]) if grammar.grammar_encoding[next_element_offset]: new_stack.append(next_element_offset) - new_stacks.update(grammar.advance_stack(tuple(new_stack))) + new_stacks.update(grammar.expand_stack_head(tuple(new_stack))) if new_stacks: check_token_acceptance_in_trie( @@ -293,29 +292,34 @@ def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts) tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenRecognizer = IncrementalTokenRecognizer( - grammar_str=input_text, start_rule_name="root", tokenizer=tokenizer + grammar_str=input_text, + start_rule_name="root", + tokenizer=tokenizer, + unicode=True, ) japanese = "トリーム" # "こんにちは" token_ids = tokenizer.encode(japanese) # 13298, 12675, 12045, 254 - stacks = tokenRecognizer._consume_token_ids( - token_ids, tokenRecognizer.string_recognizer.stacks, as_string=False - ) + init_state = None + state = tokenRecognizer._consume_token_ids(token_ids, init_state, as_string=False) - if stacks: + if state.stacks: print("The Japanese input is accepted") else: print("The Japanese input is not accepted") korean = "안녕하세요" token_ids = tokenizer.encode(korean) + init_state = tokenRecognizer.string_recognizer.get_initial_accept_state() try: - stacks = tokenRecognizer._consume_token_ids( - token_ids, tokenRecognizer.string_recognizer.stacks, as_string=False + state = tokenRecognizer._consume_token_ids( + token_ids, + init_state, + as_string=False, ) - if stacks: + if state.stacks: print("The Korean input is accepted") else: print("The Korean input is not accepted")
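A minimal usage sketch of the renamed API follows, assuming a toy grammar string, the `gpt2` tokenizer, and the import paths of the files touched above (all chosen for illustration only); the method names and call signatures (`_consume_token_ids`, `_accept_string`, `_accept_prefix`, `parse_ebnf`, `StringRecognizer`) are the ones exercised in the diff:

from transformers import AutoTokenizer

from transformers_cfg.parser import parse_ebnf
from transformers_cfg.recognizer import StringRecognizer
from transformers_cfg.token_grammar_recognizer import IncrementalTokenRecognizer

# Toy grammar, assumed for illustration only.
grammar_str = 'root ::= "hello" | "こんにちは"'

# Token-level recognition, mirroring the __main__ demo above;
# passing None as the accept state falls back to the initial accept state.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
recognizer = IncrementalTokenRecognizer(
    grammar_str=grammar_str, start_rule_name="root", tokenizer=tokenizer, unicode=True
)
token_ids = tokenizer.encode("こんにちは")
state = recognizer._consume_token_ids(token_ids, None, as_string=False)
print("accepted" if state.stacks else "rejected")

# String-level recognition, mirroring examples/generate_geo_query.py above.
parsed_grammar = parse_ebnf(grammar_str)
string_recognizer = StringRecognizer(
    parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"]
)
print(string_recognizer._accept_string("hello"))  # full match against the grammar
print(string_recognizer._accept_prefix("こん"))  # prefix match against the grammar

Both paths run through the same stack machinery (`expand_stack_head`, `_consume_code_points_for_all_stacks`) that this patch renames, so the sketch only relies on calls that appear in the diff.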