From 375f85dd573915eb758e4607b156a59e8cd0dbf6 Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 4 Apr 2024 19:25:59 +0100 Subject: [PATCH] json: improved repetitions & builtin rule deps --- examples/json_schema_to_grammar.py | 107 +++++++++++++++----------- tests/test-json-schema-to-grammar.cpp | 83 +++++++++++++------- 2 files changed, 119 insertions(+), 71 deletions(-) diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 382f1baf99e5e..783f268458396 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -6,45 +6,57 @@ import sys from typing import Any, Dict, List, Set, Tuple, Union +def _build_repetition(content, up_to_n): + # return ' '.join([content] * n) + if up_to_n == 0: + return '' + return f'({content}{" " + _build_repetition(content, up_to_n-1) if up_to_n > 1 else ""})?' + +class BuiltinRule: + def __init__(self, content: str, deps: list[str] = None): + self.content = content + self.deps = deps or [] + + def __str__(self): + assert false + +_up_to_15_digits = _build_repetition('[0-9]', 15) + # whitespace is constrained to a single space char to prevent model "running away" in # whitespace. Also maybe improves generation quality? SPACE_RULE = '" "?' - + PRIMITIVE_RULES = { - 'boolean': '("true" | "false") space', - 'decimal-part': '[0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] [0-9]?)?)?)?)?)?)?)?)?)?', - 'integral-part': '[0-9] | [1-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] [0-9]?)?)?)?)?)?)?)?)?)?', - - # 'number': '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', - # 'integer': '("-"? ([0-9] | [1-9] [0-9]*)) space', - 'number': '("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', - 'integer': '("-"? integral-part) space', - 'value' : 'object | array | string | number | boolean', - 'object' : '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', - 'array' : '"[" space ( value ("," space value)* )? "]" space', - 'uuid' : '"\\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + ' "\\"" space', - 'string': r''' "\"" ( + 'boolean': BuiltinRule('("true" | "false") space', []), + 'decimal-part': BuiltinRule('[0-9] ' + _up_to_15_digits, []), + 'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []), + 'number': BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), + 'integer': BuiltinRule('("-"? integral-part) space', ['integral-part']), + 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), + 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), + 'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), + 'uuid' : BuiltinRule('"\\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + ' "\\"" space', []), + 'string': BuiltinRule(r''' "\"" ( [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) - )* "\"" space''', - 'null': '"null" space', + )* "\"" space''', []), + 'null': BuiltinRule('"null" space', []), } -OBJECT_RULE_NAMES = ['object', 'array', 'string', 'integral-part', 'decimal-part', 'number', 'boolean', 'null', 'value'] # TODO: support "uri", "email" string formats -DATE_RULES = { - 'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', - 'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', - 'date-time': 'date "T" time', - 'date-string': '"\\"" date "\\"" space', - 'time-string': '"\\"" time "\\"" space', - 'date-time-string': '"\\"" date-time "\\"" space', +STRING_FORMAT_RULES = { + 'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), + 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), + 'date-time': BuiltinRule('date "T" time', ['date', 'time']), + 'date-string': BuiltinRule('"\\"" date "\\"" space', ['date']), + 'time-string': BuiltinRule('"\\"" time "\\"" space', ['time']), + 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), } DOTALL = '[\\U00000000-\\U0010FFFF]' DOT = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]' -RESERVED_NAMES = set(["root", *PRIMITIVE_RULES.keys(), *DATE_RULES.keys()]) +RESERVED_NAMES = set(["root", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') @@ -54,8 +66,6 @@ NON_LITERAL_SET = set('|.()[]{}*+?') ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?') -DATE_PATTERN = '[0-9]{4}-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])' -TIME_PATTERN = '([01][0-9]|2[0-3])(:[0-5][0-9]){2}(\\.[0-9]{1,3})?(Z|[+-](([01][0-9]|2[0-3]):[0-5][0-9]))' # Cap millisecond precision w/ 3 digits class SchemaConverter: def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): @@ -65,8 +75,6 @@ def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): self._raw_pattern = raw_pattern self._rules = { 'space': SPACE_RULE, - 'integral-part': PRIMITIVE_RULES['integral-part'], - 'decimal-part': PRIMITIVE_RULES['decimal-part'], } self._refs = {} self._refs_being_resolved = set() @@ -420,7 +428,9 @@ def add_component(comp_schema, is_required): successive_items = list_item_operator * (min_items - 1) min_items -= 1 if max_items is not None and max_items > min_items: - successive_items += (list_item_operator + "?") * (max_items - min_items - 1) + # TODO: avoid grammar branch explosion here + successive_items += _build_repetition(list_item_operator, max_items - min_items - 1) + # successive_items += (list_item_operator + "?") * (max_items - min_items - 1) else: successive_items += list_item_operator + "*" if min_items == 0: @@ -433,28 +443,39 @@ def add_component(comp_schema, is_required): return self._visit_pattern(schema['pattern'], rule_name) elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''): - return self._add_rule( + return self._add_primitive( 'root' if rule_name == 'root' else schema_format, PRIMITIVE_RULES['uuid'] ) - elif schema_type in (None, 'string') and schema_format in DATE_RULES: - for t, r in DATE_RULES.items(): - self._add_rule(t, r) - return schema_format + '-string' + elif schema_type in (None, 'string') and schema_format in STRING_FORMAT_RULES: + return self._add_rule(rule_name, self._add_primitive(schema_format, STRING_FORMAT_RULES[schema_format])) elif (schema_type == 'object') or (len(schema) == 0): - for n in OBJECT_RULE_NAMES: - self._add_rule(n, PRIMITIVE_RULES[n]) - return self._add_rule(rule_name, 'object') + return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) else: assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero - return self._add_rule( - 'root' if rule_name == 'root' else schema_type, - PRIMITIVE_RULES[schema_type] - ) + return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type]) + + def _add_primitive(self, name: str, rule: BuiltinRule): + assert isinstance(rule, BuiltinRule), f'rule: {rule}' + assert isinstance(rule.content, str), f'{name}: {rule.content}' + n = self._add_rule(name, rule.content) + + for dep in rule.deps: + dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) + assert dep_rule, f'Rule {dep} not known' + if dep not in self._rules: + self._add_primitive(dep, dep_rule) + return n + + def _build_number_rule(self): + _up_to_15_digits = _build_repetition('[0-9]', 15) + decimal_rule = self._add_rule('decimal-part', f'[0-9] {_up_to_15_digits}') + integral_rule = self._add_rule('integral-part', f'[0-9] | [1-9] {_up_to_15_digits}') + return self._add_rule('number', f'("-"? {integral_rule}) ("." {decimal_rule})? ([eE] [-+]? {integral_rule})? space') def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]): prop_order = self._prop_order @@ -476,7 +497,7 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value') prop_kv_rule_names["*"] = self._add_rule( f'{sub_name}-kv', - self._add_rule('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' + self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}' ) optional_props.append("*") diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 848b925a83247..6c71f31152f7c 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -104,16 +104,18 @@ static void test_all(const std::string & lang, std::function