Skip to content

Commit

Permalink
Fix unicode error (#42)
Browse files Browse the repository at this point in the history
* fix unicode error

* style: rename some methods for better clarity
  • Loading branch information
Saibo-creator authored May 2, 2024
1 parent 57c3991 commit cc292d0
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 74 deletions.
12 changes: 12 additions & 0 deletions docs/contribute.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Contribute

We welcome contributions to the project.

To contribute, please follow these steps:
1. Fork the repository.
2. Create a new branch for your changes with a descriptive name, e.g. `git checkout -b feature/add-support-for-xyz` or `git checkout -b fix/parsing-error-in-abc`.
3. Create an environment with the required dependencies via `pip install -r requirements.txt`.
4. Install `pre-commit` hooks via `pre-commit install`.
5. Make your changes and add tests to ensure your changes are correct.
6. Commit them, `pre-commit` will run automatically when you commit. Tests will be run to ensure your changes are correct.
7. If all tests pass, push your changes to your fork and create a pull request.
42 changes: 29 additions & 13 deletions examples/generate_geo_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
from transformers_cfg.parser import parse_ebnf


def parse_args():
parser = argparse.ArgumentParser(description="Generate geo query strings")
parser.add_argument("--model-id", type=str, default="/dlabdata1/llm_hub/Mistral-7B-v0.1", help="Model ID")
parser.add_argument(
"--model-id",
type=str,
default="/dlabdata1/llm_hub/Mistral-7B-v0.1",
help="Model ID",
)
parser.add_argument("--device", type=str, help="Device to put the model on")
return parser.parse_args()

Expand All @@ -18,14 +24,16 @@ def main():
model_id = args.model_id

# Detect if GPU is available, otherwise use CPU
device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
device = torch.device(
args.device or ("cuda" if torch.cuda.is_available() else "cpu")
)
print(f"Using device: {device}")

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# Load model to defined device
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# Load grammar
with open(f"examples/grammars/geo_query.ebnf", "r") as file:
Expand All @@ -40,7 +48,7 @@ def main():
"which state contains most rivers ? ",
"number of citizens in boulder ? ",
"what are the major cities of the us ? ",
"what is the smallest city in washington ? ",
"what is the smallest city in washington ? ",
"how many states border colorado and border new mexico ? ",
]

Expand All @@ -49,7 +57,7 @@ def main():
)["input_ids"].to(
device
) # Move input_ids to the same device as model

n_examples = input_ids.shape[0]

max_new_tokens = 50
Expand All @@ -70,27 +78,35 @@ def main():
)

parsed_grammar = parse_ebnf(grammar_str)
string_grammar = StringRecognizer(parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"])

string_grammar = StringRecognizer(
parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"]
)

# decode outputs (possibly of different lengths across decoding modes)
generations = tokenizer.batch_decode(unconstrained_output, skip_special_tokens=True) + \
tokenizer.batch_decode(constrained_output, skip_special_tokens=True)
generations = tokenizer.batch_decode(
unconstrained_output, skip_special_tokens=True
) + tokenizer.batch_decode(constrained_output, skip_special_tokens=True)
print()
for i in range(n_examples):
print(f"Unconstrained: {generations[i]}")
constrained_generation = generations[i + n_examples]
print(f"Constrained: {generations[i + n_examples]}")
print(f'The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}')
print(f'The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}')
print(
f"The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}"
)
print(
f"The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}"
)
print()


if __name__ == "__main__":
main()


##########################
# Example output:
#
# Example output:
#
# Unconstrained: how many states border colorado and border new mexico ? 1.
# - How long is the drive from denver to albuquerque? The distance between Denver, Colorado (CO) & Alburqueque New Mexico(NM). Driving directions for your road trip or vacation: Get driving
# Constrained: how many states border colorado and border new mexico ? answer(smallest_one(area_1(stateid('colorado'))))
Expand Down
1 change: 0 additions & 1 deletion tests/test_parsing/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def test__parse_literal_string(self):
self.assertEqual(3, len(outbuf), f"len(outbuf): {len(outbuf)} != 3")
self.assertListEqual([2, ord("你"), ord("你")], outbuf)


def test__parse_escape(self):
escaped_char_src = '"\\n"'
outbuf = []
Expand Down
2 changes: 1 addition & 1 deletion transformers_cfg/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def process_logits(self, input_ids, scores):
)
# logger.debug("stacks: \n" + pprint.pformat(self.batch_accept_states.stacks))

self.batch_accept_states = self.grammar_constraint.advance_token_ids(
self.batch_accept_states = self.grammar_constraint.consume_token_ids(
input_ids, self.batch_accept_states, self.parse_start_index
)
logger.debug(f"input_ids: {input_ids}")
Expand Down
109 changes: 65 additions & 44 deletions transformers_cfg/recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def init_stack(self, start_rule_id: int) -> Set[Tuple[int]]:
element_offset = sub_rhs_offset + 1
if self.grammar_encoding[element_offset] != END_OF_ALTERNATE_MARKER:
stack.append(element_offset)
stacks.update(self.advance_stack(tuple(stack)))
stacks.update(self.expand_stack_head(tuple(stack)))
sub_rhs_offset += 1 + self.grammar_encoding[sub_rhs_offset]
return stacks

Expand All @@ -104,7 +104,15 @@ def get_termination_accept_state(self) -> AcceptState:
return AcceptState(set(), PartialUTF8())

@lru_cache(maxsize=32768)
def advance_stack(self, stack: Tuple[int]) -> Set[Tuple[int]]:
def expand_stack_head(self, stack: Tuple[int]) -> Set[Tuple[int]]:
"""
Stack is the internal state of the recognizer(Pushdown Automaton).
This method updates the stack by advancing it to the next element.
If the element is a non-terminal, it expands the stack by adding the elements of the referenced rule.
A new stack is created for each alternate of the referenced rule, so we could have multiple stacks as output.
:param stack:
:return:
"""
if len(stack) == 0:
return {stack}

Expand Down Expand Up @@ -137,32 +145,36 @@ def advance_stack(self, stack: Tuple[int]) -> Set[Tuple[int]]:
if self.grammar_encoding[ref_element_offset] != END_OF_ALTERNATE_MARKER:
new_stack.append(ref_element_offset)

new_stacks.update(self.advance_stack(tuple(new_stack)))
new_stacks.update(self.expand_stack_head(tuple(new_stack)))
ref_subrule_offset += self.grammar_encoding[ref_subrule_offset] + 1

return new_stacks

def _consume_byte(self, byte: int, accept_state: AcceptState):
def _consume_byte(self, byte: int, accept_state: AcceptState) -> AcceptState:
# suppose we have code point 一, ord('一') = 19968, we need to match 3 bytes
# we need to match 3 bytes, so we need to call _consume_byte_partial_match 3 times
self._consume_bytes(bytes([byte]), accept_state)
return self._consume_bytes(bytes([byte]), accept_state)

# @lru_cache(maxsize=32768)
def _probe_bytes(
def _try_accept_bytes(
self,
byte_seq: bytes,
stacks: Set[Tuple[int]],
partial_utf8: PartialUTF8,
verbose=True,
):
"""
The difference between accept_bytes and consume_bytes is that accept_bytes returns a boolean and
consume_bytes returns a new accept state
"""
if type(byte_seq) is list:
byte_seq = bytes(byte_seq)
code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8)
if verbose:
logging.debug(
f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
)
new_stacks = self._consume_code_points(code_points, stacks)
new_stacks = self._consume_code_points_for_all_stacks(code_points, stacks)

for stack in new_stacks:

Expand All @@ -179,7 +191,7 @@ def _consume_bytes(
byte_seq: bytes,
accept_state: Optional[AcceptState] = None,
verbose=True,
):
) -> AcceptState:
if accept_state is None:
accept_state = self.get_initial_accept_state()
stacks = accept_state.stacks
Expand All @@ -191,7 +203,7 @@ def _consume_bytes(
logging.debug(
f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
)
new_stacks = self._consume_code_points(code_points, stacks)
new_stacks = self._consume_code_points_for_all_stacks(code_points, stacks)

new_new_stacks = set()
for stack in new_stacks:
Expand All @@ -209,8 +221,8 @@ def _consume_bytes(
##########################

@lru_cache(maxsize=30000)
def _consume_code_point(
self, code_point: int, stacks: Set[Tuple[int]]
def _consume_code_point_for_all_stacks(
self, code_point: int, stacks: Tuple[Tuple[int]]
) -> Set[Tuple[int]]:
"""
consume a character from the stack
Expand All @@ -221,11 +233,13 @@ def _consume_code_point(
if code_point == 0:
return new_stacks
for stack in stacks:
new_stacks.update(self._consume_code_point_per_stack(code_point, stack))
new_stacks.update(
self._consume_code_point_for_single_stack(code_point, stack)
)
return new_stacks

@lru_cache(maxsize=30000)
def _consume_code_point_per_stack(
def _consume_code_point_for_single_stack(
self, code_point: int, stack: Tuple[int]
) -> Set[Tuple[int]]:
"""
Expand Down Expand Up @@ -256,15 +270,22 @@ def _consume_code_point_per_stack(
new_stack = list(stack[:-1])
if self.grammar_encoding[element_offset]:
new_stack.append(element_offset)
return self.advance_stack(tuple(new_stack))
# # Explicitly convert list to tuple of int to make it hashable
new_tuple_stack: Tuple[int, ...] = tuple(new_stack)
return self.expand_stack_head(new_tuple_stack)

def _consume_code_points(
def _consume_code_points_for_all_stacks(
self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False
) -> Set[Tuple[int]]:
"""
code points is a list of Unicode code points. For example, the code points for "hello" is [104, 101, 108, 108, 111]
For unicode string "こんにちは世界", the code points are [12371, 12435, 12395, 12385, 12399, 19990, 30028]
"""
for i, code_point in enumerate(code_points):
# for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
tuple_stacks: Tuple[Tuple[int], ...] = tuple(stacks)
stacks = self._consume_code_point(code_point, tuple_stacks)
stacks = self._consume_code_point_for_all_stacks(code_point, tuple_stacks)
if len(stacks) > 0 and verbose:
accepted_code_point = code_points[: i + 1]
corresponding_char = chr(code_point)
Expand All @@ -276,7 +297,7 @@ def _consume_code_points(
def _accept_code_points(
self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False
) -> bool:
stacks = self._consume_code_points(code_points, stacks, verbose)
stacks = self._consume_code_points_for_all_stacks(code_points, stacks, verbose)
return len(stacks) > 0

@lru_cache(maxsize=30000)
Expand All @@ -295,12 +316,6 @@ def accept_code_point_at_element(
return True
return False

# def _accept_code_point(self, code_point: int, stacks: List[List[int]]):
# # for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
# tuple_stacks: Tuple[Tuple[int]] = tuple([tuple(stack) for stack in stacks])
# new_stacks: List[List[int]] = self._consume_code_point(code_point, tuple_stacks)
# return len(new_stacks) > 0

#############################
#
# Partial UTF-8 recognition
Expand Down Expand Up @@ -363,7 +378,9 @@ def partial_utf8_accept_at_element(
def _consume_string(self, string: str, accept_state: AcceptState):
# _bytes = bytes(string, "utf-8")
code_points = [ord(char) for char in string]
stacks = self._consume_code_points(code_points, accept_state.stacks)
stacks = self._consume_code_points_for_all_stacks(
code_points, accept_state.stacks
)
return AcceptState(stacks, accept_state.partial_utf8)

def _accept_prefix(self, string: str, accept_state: Optional[AcceptState] = None):
Expand Down Expand Up @@ -427,23 +444,27 @@ def char_acceptance_at_element(self, element_offset):
logging.debug(acceptance)
return acceptance

def _consume_code_points_new(
self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False
) -> Set[Tuple[int]]:
new_stacks: Set[Tuple[int]] = set()
for stack in stacks:
new_stacks.update(
self._consume_code_points_per_stack(tuple(code_points), stack, verbose)
)
return new_stacks

@lru_cache(maxsize=30000)
def _consume_code_points_per_stack(
self, code_points: Tuple[int], stack: Tuple[int], verbose=False
) -> Set[Tuple[int]]:
for code_point in code_points:
stacks = self._consume_code_point(code_point, (stack,))
return stacks
# def _consume_code_points_new(
# self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False
# ) -> Set[Tuple[int]]:
# new_stacks: Set[Tuple[int]] = set()
# for stack in stacks:
# new_stacks.update(
# self._consume_code_points_per_stack(tuple(code_points), stack, verbose)
# )
# return new_stacks
#
# @lru_cache(maxsize=30000)
# def _consume_code_points_per_stack(
# self, code_points: Tuple[int], stack: Tuple[int], verbose=False
# ) -> Set[Tuple[int]]:
# stacks = {stack}
#
# for code_point in code_points:
# # Update the stacks variable by consuming each code point.
# stacks = self._consume_code_point_for_all_stacks(code_point, (stack,))
#
# return stacks


if __name__ == "__main__":
Expand Down Expand Up @@ -487,9 +508,9 @@ def _consume_code_points_per_stack(

accept_state = AcceptState(recognizer.stacks, PartialUTF8())
for i, byte in enumerate(byte_tokens):
new_accept_state = recognizer._consume_bytes(byte, accept_state)
logging.debug(f"new partial utf8: {new_accept_state.partial_utf8}")
if len(new_accept_state.stacks) > 0:
accept_state = recognizer._consume_bytes(byte, accept_state)
logging.debug(f"new partial utf8: {accept_state.partial_utf8}")
if len(accept_state.stacks) > 0:
logging.debug(f"byte {byte} is accepted")
else:
logging.debug(f"byte {byte} is not accepted")
Expand Down
Loading

0 comments on commit cc292d0

Please sign in to comment.