Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unicode error #42

Merged
merged 2 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/contribute.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Contribute

We welcome contributions to the project.

To contribute, please follow these steps:
1. Fork the repository.
2. Create a new branch for your changes with a descriptive name, e.g. `git checkout -b feature/add-support-for-xyz` or `git checkout -b fix/parsing-error-in-abc`.
3. Create an environment with the required dependencies via `pip install -r requirements.txt`.
4. Install `pre-commit` hooks via `pre-commit install`.
5. Make your changes and add tests to ensure your changes are correct.
6. Commit your changes; `pre-commit` will run automatically when you commit, and the tests will be run to ensure your changes are correct.
7. If all tests pass, push your changes to your fork and create a pull request.
42 changes: 29 additions & 13 deletions examples/generate_geo_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
from transformers_cfg.parser import parse_ebnf


def parse_args():
parser = argparse.ArgumentParser(description="Generate geo query strings")
parser.add_argument("--model-id", type=str, default="/dlabdata1/llm_hub/Mistral-7B-v0.1", help="Model ID")
parser.add_argument(
"--model-id",
type=str,
default="/dlabdata1/llm_hub/Mistral-7B-v0.1",
help="Model ID",
)
parser.add_argument("--device", type=str, help="Device to put the model on")
return parser.parse_args()

Expand All @@ -18,14 +24,16 @@ def main():
model_id = args.model_id

# Detect if GPU is available, otherwise use CPU
device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
device = torch.device(
args.device or ("cuda" if torch.cuda.is_available() else "cpu")
)
print(f"Using device: {device}")

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# Load model to defined device
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# Load grammar
with open(f"examples/grammars/geo_query.ebnf", "r") as file:
Expand All @@ -40,7 +48,7 @@ def main():
"which state contains most rivers ? ",
"number of citizens in boulder ? ",
"what are the major cities of the us ? ",
"what is the smallest city in washington ? ",
"what is the smallest city in washington ? ",
"how many states border colorado and border new mexico ? ",
]

Expand All @@ -49,7 +57,7 @@ def main():
)["input_ids"].to(
device
) # Move input_ids to the same device as model

n_examples = input_ids.shape[0]

max_new_tokens = 50
Expand All @@ -70,27 +78,35 @@ def main():
)

parsed_grammar = parse_ebnf(grammar_str)
string_grammar = StringRecognizer(parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"])

string_grammar = StringRecognizer(
parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"]
)

# decode outputs (possibly of different lengths across decoding modes)
generations = tokenizer.batch_decode(unconstrained_output, skip_special_tokens=True) + \
tokenizer.batch_decode(constrained_output, skip_special_tokens=True)
generations = tokenizer.batch_decode(
unconstrained_output, skip_special_tokens=True
) + tokenizer.batch_decode(constrained_output, skip_special_tokens=True)
print()
for i in range(n_examples):
print(f"Unconstrained: {generations[i]}")
constrained_generation = generations[i + n_examples]
print(f"Constrained: {generations[i + n_examples]}")
print(f'The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}')
print(f'The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}')
print(
f"The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}"
)
print(
f"The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}"
)
print()


if __name__ == "__main__":
main()


##########################
# Example output:
#
# Example output:
#
# Unconstrained: how many states border colorado and border new mexico ? 1.
# - How long is the drive from denver to albuquerque? The distance between Denver, Colorado (CO) & Alburqueque New Mexico(NM). Driving directions for your road trip or vacation: Get driving
# Constrained: how many states border colorado and border new mexico ? answer(smallest_one(area_1(stateid('colorado'))))
Expand Down
1 change: 0 additions & 1 deletion tests/test_parsing/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def test__parse_literal_string(self):
self.assertEqual(3, len(outbuf), f"len(outbuf): {len(outbuf)} != 3")
self.assertListEqual([2, ord("你"), ord("你")], outbuf)


def test__parse_escape(self):
escaped_char_src = '"\\n"'
outbuf = []
Expand Down
2 changes: 1 addition & 1 deletion transformers_cfg/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def process_logits(self, input_ids, scores):
)
# logger.debug("stacks: \n" + pprint.pformat(self.batch_accept_states.stacks))

self.batch_accept_states = self.grammar_constraint.advance_token_ids(
self.batch_accept_states = self.grammar_constraint.consume_token_ids(
input_ids, self.batch_accept_states, self.parse_start_index
)
logger.debug(f"input_ids: {input_ids}")
Expand Down
109 changes: 65 additions & 44 deletions transformers_cfg/recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def init_stack(self, start_rule_id: int) -> Set[Tuple[int]]:
element_offset = sub_rhs_offset + 1
if self.grammar_encoding[element_offset] != END_OF_ALTERNATE_MARKER:
stack.append(element_offset)
stacks.update(self.advance_stack(tuple(stack)))
stacks.update(self.expand_stack_head(tuple(stack)))
sub_rhs_offset += 1 + self.grammar_encoding[sub_rhs_offset]
return stacks

Expand All @@ -104,7 +104,15 @@ def get_termination_accept_state(self) -> AcceptState:
return AcceptState(set(), PartialUTF8())

@lru_cache(maxsize=32768)
def advance_stack(self, stack: Tuple[int]) -> Set[Tuple[int]]:
def expand_stack_head(self, stack: Tuple[int]) -> Set[Tuple[int]]:
"""
The stack is the internal state of the recognizer (a pushdown automaton).
This method updates the stack by advancing it to the next element.
If the element is a non-terminal, it expands the stack by adding the elements of the referenced rule.
A new stack is created for each alternate of the referenced rule, so we could have multiple stacks as output.
:param stack:
:return:
"""
if len(stack) == 0:
return {stack}

Expand Down Expand Up @@ -137,32 +145,36 @@ def advance_stack(self, stack: Tuple[int]) -> Set[Tuple[int]]:
if self.grammar_encoding[ref_element_offset] != END_OF_ALTERNATE_MARKER:
new_stack.append(ref_element_offset)

new_stacks.update(self.advance_stack(tuple(new_stack)))
new_stacks.update(self.expand_stack_head(tuple(new_stack)))
ref_subrule_offset += self.grammar_encoding[ref_subrule_offset] + 1

return new_stacks

def _consume_byte(self, byte: int, accept_state: AcceptState):
def _consume_byte(self, byte: int, accept_state: AcceptState) -> AcceptState:
# suppose we have code point 一, ord('一') = 19968, we need to match 3 bytes
# we need to match 3 bytes, so we need to call _consume_byte_partial_match 3 times
self._consume_bytes(bytes([byte]), accept_state)
return self._consume_bytes(bytes([byte]), accept_state)

# @lru_cache(maxsize=32768)
def _probe_bytes(
def _try_accept_bytes(
self,
byte_seq: bytes,
stacks: Set[Tuple[int]],
partial_utf8: PartialUTF8,
verbose=True,
):
"""
The difference between accept_bytes and consume_bytes is that accept_bytes returns a boolean and
consume_bytes returns a new accept state
"""
if type(byte_seq) is list:
byte_seq = bytes(byte_seq)
code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8)
if verbose:
logging.debug(
f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
)
new_stacks = self._consume_code_points(code_points, stacks)
new_stacks = self._consume_code_points_for_all_stacks(code_points, stacks)

for stack in new_stacks:

Expand All @@ -179,7 +191,7 @@ def _consume_bytes(
byte_seq: bytes,
accept_state: Optional[AcceptState] = None,
verbose=True,
):
) -> AcceptState:
if accept_state is None:
accept_state = self.get_initial_accept_state()
stacks = accept_state.stacks
Expand All @@ -191,7 +203,7 @@ def _consume_bytes(
logging.debug(
f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
)
new_stacks = self._consume_code_points(code_points, stacks)
new_stacks = self._consume_code_points_for_all_stacks(code_points, stacks)

new_new_stacks = set()
for stack in new_stacks:
Expand All @@ -209,8 +221,8 @@ def _consume_bytes(
##########################

@lru_cache(maxsize=30000)
def _consume_code_point(
self, code_point: int, stacks: Set[Tuple[int]]
def _consume_code_point_for_all_stacks(
self, code_point: int, stacks: Tuple[Tuple[int]]
) -> Set[Tuple[int]]:
"""
consume a character from the stack
Expand All @@ -221,11 +233,13 @@ def _consume_code_point(
if code_point == 0:
return new_stacks
for stack in stacks:
new_stacks.update(self._consume_code_point_per_stack(code_point, stack))
new_stacks.update(
self._consume_code_point_for_single_stack(code_point, stack)
)
return new_stacks

@lru_cache(maxsize=30000)
def _consume_code_point_per_stack(
def _consume_code_point_for_single_stack(
self, code_point: int, stack: Tuple[int]
) -> Set[Tuple[int]]:
"""
Expand Down Expand Up @@ -256,15 +270,22 @@ def _consume_code_point_per_stack(
new_stack = list(stack[:-1])
if self.grammar_encoding[element_offset]:
new_stack.append(element_offset)
return self.advance_stack(tuple(new_stack))
# Explicitly convert the list to a tuple of ints to make it hashable
new_tuple_stack: Tuple[int, ...] = tuple(new_stack)
return self.expand_stack_head(new_tuple_stack)

def _consume_code_points(
def _consume_code_points_for_all_stacks(
self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False
) -> Set[Tuple[int]]:
"""
code_points is a list of Unicode code points. For example, the code points for "hello" are [104, 101, 108, 108, 111]
For unicode string "こんにちは世界", the code points are [12371, 12435, 12395, 12385, 12399, 19990, 30028]

"""
for i, code_point in enumerate(code_points):
# for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
tuple_stacks: Tuple[Tuple[int], ...] = tuple(stacks)
stacks = self._consume_code_point(code_point, tuple_stacks)
stacks = self._consume_code_point_for_all_stacks(code_point, tuple_stacks)
if len(stacks) > 0 and verbose:
accepted_code_point = code_points[: i + 1]
corresponding_char = chr(code_point)
Expand All @@ -276,7 +297,7 @@ def _consume_code_points(
def _accept_code_points(
self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False
) -> bool:
stacks = self._consume_code_points(code_points, stacks, verbose)
stacks = self._consume_code_points_for_all_stacks(code_points, stacks, verbose)
return len(stacks) > 0

@lru_cache(maxsize=30000)
Expand All @@ -295,12 +316,6 @@ def accept_code_point_at_element(
return True
return False

# def _accept_code_point(self, code_point: int, stacks: List[List[int]]):
# # for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
# tuple_stacks: Tuple[Tuple[int]] = tuple([tuple(stack) for stack in stacks])
# new_stacks: List[List[int]] = self._consume_code_point(code_point, tuple_stacks)
# return len(new_stacks) > 0

#############################
#
# Partial UTF-8 recognition
Expand Down Expand Up @@ -363,7 +378,9 @@ def partial_utf8_accept_at_element(
def _consume_string(self, string: str, accept_state: AcceptState):
# _bytes = bytes(string, "utf-8")
code_points = [ord(char) for char in string]
stacks = self._consume_code_points(code_points, accept_state.stacks)
stacks = self._consume_code_points_for_all_stacks(
code_points, accept_state.stacks
)
return AcceptState(stacks, accept_state.partial_utf8)

def _accept_prefix(self, string: str, accept_state: Optional[AcceptState] = None):
Expand Down Expand Up @@ -427,23 +444,27 @@ def char_acceptance_at_element(self, element_offset):
logging.debug(acceptance)
return acceptance

def _consume_code_points_new(
self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False
) -> Set[Tuple[int]]:
new_stacks: Set[Tuple[int]] = set()
for stack in stacks:
new_stacks.update(
self._consume_code_points_per_stack(tuple(code_points), stack, verbose)
)
return new_stacks

@lru_cache(maxsize=30000)
def _consume_code_points_per_stack(
self, code_points: Tuple[int], stack: Tuple[int], verbose=False
) -> Set[Tuple[int]]:
for code_point in code_points:
stacks = self._consume_code_point(code_point, (stack,))
return stacks
# def _consume_code_points_new(
# self, code_points: List[int], stacks: Set[Tuple[int]], verbose=False
# ) -> Set[Tuple[int]]:
# new_stacks: Set[Tuple[int]] = set()
# for stack in stacks:
# new_stacks.update(
# self._consume_code_points_per_stack(tuple(code_points), stack, verbose)
# )
# return new_stacks
#
# @lru_cache(maxsize=30000)
# def _consume_code_points_per_stack(
# self, code_points: Tuple[int], stack: Tuple[int], verbose=False
# ) -> Set[Tuple[int]]:
# stacks = {stack}
#
# for code_point in code_points:
# # Update the stacks variable by consuming each code point.
# stacks = self._consume_code_point_for_all_stacks(code_point, (stack,))
#
# return stacks


if __name__ == "__main__":
Expand Down Expand Up @@ -487,9 +508,9 @@ def _consume_code_points_per_stack(

accept_state = AcceptState(recognizer.stacks, PartialUTF8())
for i, byte in enumerate(byte_tokens):
new_accept_state = recognizer._consume_bytes(byte, accept_state)
logging.debug(f"new partial utf8: {new_accept_state.partial_utf8}")
if len(new_accept_state.stacks) > 0:
accept_state = recognizer._consume_bytes(byte, accept_state)
logging.debug(f"new partial utf8: {accept_state.partial_utf8}")
if len(accept_state.stacks) > 0:
logging.debug(f"byte {byte} is accepted")
else:
logging.debug(f"byte {byte} is not accepted")
Expand Down
Loading