diff --git a/tests/test_parsing/test_parsing.py b/tests/test_parsing/test_parsing.py index 1a8d669..aa35f32 100644 --- a/tests/test_parsing/test_parsing.py +++ b/tests/test_parsing/test_parsing.py @@ -138,6 +138,31 @@ def test__parse_literal_string(self): self.assertEqual(3, len(outbuf), f"len(outbuf): {len(outbuf)} != 3") self.assertListEqual([2, ord("你"), ord("你")], outbuf) + + def test__parse_escape(self): + escaped_char_src = '"\\n"' + outbuf = [] + + remaining_src = _parse_rhs_literal_string(escaped_char_src, outbuf) + self.assertEqual(3, len(outbuf), f"len(outbuf): {len(outbuf)} != 3") + self.assertListEqual([2, ord("\n"), ord("\n")], outbuf) + + escaped_backslash_src = '"\\\\"' + outbuf = [] + + remaining_src = _parse_rhs_literal_string(escaped_backslash_src, outbuf) + self.assertEqual(3, len(outbuf), f"len(outbuf): {len(outbuf)} != 3") + self.assertListEqual([2, ord("\\"), ord("\\")], outbuf) + self.assertEqual("", remaining_src, f"remaining_src: {remaining_src} != ''") + + escaped_backslash_src = '"\\x5C"' + outbuf = [] + + remaining_src = _parse_rhs_literal_string(escaped_backslash_src, outbuf) + self.assertEqual(3, len(outbuf), f"len(outbuf): {len(outbuf)} != 3") + self.assertListEqual([2, ord("\\"), ord("\\")], outbuf) + self.assertEqual("", remaining_src, f"remaining_src: {remaining_src} != ''") + def test_null(self): src = "root ::= " rhs_src = "" @@ -296,6 +321,15 @@ def test_parse_rhs(self): ) logging.debug(f"state.grammar_encoding of {rhs_src}: {state.grammar_encoding}") + src = 'root ::= "\\\\"' + rhs_src = '"\\\\"' + state = ParseState() + state.symbol_table["root"] = 9 + _ = parse_rhs( + state=state, rhs=rhs_src, rule_name="root", rule_id=9, is_nested=False + ) + logging.debug(f"state.grammar_encoding of {rhs_src}: {state.grammar_encoding}") + def test__parse_symbol_reference(self): state = ParseState() outbuf = [] diff --git a/transformers_cfg/parser.py b/transformers_cfg/parser.py index 26a9eb1..abfa25e 100644 --- a/transformers_cfg/parser.py +++ b/transformers_cfg/parser.py @@ -127,7 +127,7 @@ def parse_char(src) -> (str, str): if first > -1: second = hex_to_int(src[3]) if second > -1: - return (first << 4) + second, src[4:] + return chr((first << 4) + second), src[4:] raise RuntimeError("expecting \\xNN at " + src) elif esc in ('"', "[", "]"): return esc, src[2:]