Skip to content

Commit

Permalink
Improve sample generation
Browse files Browse the repository at this point in the history
  • Loading branch information
otto-ifak committed Jul 14, 2024
1 parent 8c21e55 commit b04db39
Show file tree
Hide file tree
Showing 15 changed files with 354 additions and 426 deletions.
4 changes: 3 additions & 1 deletion fences/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Some convenience imports
from .json_schema.parse import parse as parse_json_schema
from .regex.parse import parse as parse_regex
from .xml_schema.parse import parse as parse_xml_schema
from .grammar.convert import convert as parse_grammar
from .open_api.generate import parse as parse_open_api
from .open_api.generate import parse_operation
from .open_api.open_api import OpenApi
80 changes: 44 additions & 36 deletions fences/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class Node:
def __init__(self, id: Optional[str] = None) -> None:
self.id = id
self.incoming_transitions: List["IncomingTransition"] = []
self._has_valid_leafs: bool = False

def apply(self, data: any) -> any:
"""
Expand Down Expand Up @@ -126,39 +125,43 @@ def _execute(self, path: Path, path_idx: int, data: any) -> Tuple[int, any]:
idx = path[path_idx]
return self.outgoing_transitions[idx].target._execute(path, path_idx+1, data)

def _generate(self, result_path: Path, already_reached: Set):
def _generate(self, result_path: Path, already_reached: Set) -> bool:
already_reached.add(id(self))

if not isinstance(self, Decision):
return
return True
if not self.outgoing_transitions:
return
return True

satisfiable = True

if self.all_transitions:
for transition in self.outgoing_transitions:
transition._num_paths += 1
transition.target._generate(result_path, already_reached)
sub_satisfiable = transition.target._generate(result_path, already_reached)
satisfiable = satisfiable and sub_satisfiable
else:
selected = None
min_paths = float('inf')
for idx, transition in enumerate(self.outgoing_transitions):
target = transition.target
if transition._num_paths < min_paths and target._has_valid_leafs:
if transition._num_paths < min_paths and transition._satisfiable:
selected = idx
min_paths = transition._num_paths
# No satisfiable transition found, fallback to an un-satisfiable one
if selected is None:
print("No valid leaf detected, falling back to invalid one")
satisfiable = False
# print("No valid leaf detected, falling back to invalid one")
for idx, transition in enumerate(self.outgoing_transitions):
target = transition.target
if transition._num_paths < min_paths:
selected = idx
min_paths = transition._num_paths

result_path.append(selected)
transition: OutgoingTransition = self.outgoing_transitions[selected]
transition._num_paths += 1
transition.target._generate(result_path, already_reached)
sub_satisfiable = transition.target._generate(result_path, already_reached)
satisfiable = satisfiable and sub_satisfiable
return satisfiable

def _backward(self, path: Path, already_reached: Set) -> "Node":
already_reached.add(id(self))
Expand All @@ -178,42 +181,40 @@ def _backward(self, path: Path, already_reached: Set) -> "Node":
root = predecessor_transition.source._backward(path, already_reached)
return root

def _forward(self, backward_path: Path, forward_path: Path, visited: Set):
def _forward(self, backward_path: Path, forward_path: Path, visited: Set) -> bool:
if len(backward_path) == 0:
return
return True
assert isinstance(self, Decision)
path_idx = backward_path.pop(-1)
if self.all_transitions:
satisfiable = True
for idx, transition in enumerate(self.outgoing_transitions):
if idx == path_idx:
transition.target._forward(
backward_path, forward_path, visited)
s = transition.target._forward(backward_path, forward_path, visited)
else:
transition._num_paths += 1
transition.target._generate(forward_path, visited)
s = transition.target._generate(forward_path, visited)
satisfiable = satisfiable and s
else:
transition = self.outgoing_transitions[path_idx]
forward_path.append(path_idx)
transition._num_paths += 1
transition.target._forward(backward_path, forward_path, visited)
satisfiable = transition.target._forward(backward_path, forward_path, visited)

def _collect(self, visited: Set[str], valid_leafs: List["Leaf"], invalid_leafs: List["Leaf"]):
if id(self) in visited:
return
visited.add(id(self))
return satisfiable

if isinstance(self, Leaf):
if self.is_valid:
valid_leafs.append(self)
else:
invalid_leafs.append(self)
self._has_valid_leafs = self.is_valid
else:
assert isinstance(self, Decision)
self._has_valid_leafs = False
for i in self.outgoing_transitions:
i.target._collect(visited, valid_leafs, invalid_leafs)
self._has_valid_leafs = self._has_valid_leafs or i.target._has_valid_leafs
def _mark_satisfiable(self):

if isinstance(self, Decision):
if self.all_transitions:
if any(not i._satisfiable for i in self.outgoing_transitions):
return

for i in self.incoming_transitions:
out = i.outgoing_transition()
if not out._satisfiable:
out._satisfiable = True
i.source._mark_satisfiable()

def generate_paths(self) -> Generator[ResultEntry, None, None]:
"""
Expand All @@ -222,10 +223,16 @@ def generate_paths(self) -> Generator[ResultEntry, None, None]:
"""

# Reset counter, collect leafs
visited = set()
valid_nodes: List[Leaf] = []
invalid_nodes: List[Leaf] = []
self._collect(visited, valid_nodes, invalid_nodes)
for i in self.items():
if isinstance(i, Leaf):
if i.is_valid:
valid_nodes.append(i)
else:
invalid_nodes.append(i)
for leaf in valid_nodes:
leaf._mark_satisfiable()

# Visit valid nodes first
to_visit = valid_nodes + invalid_nodes
Expand All @@ -241,10 +248,10 @@ def generate_paths(self) -> Generator[ResultEntry, None, None]:

# Follow path to the target node
forward_path = []
root._forward(backward_path, forward_path, visited)
satisfiable = root._forward(backward_path, forward_path, visited)

# Yield
yield ResultEntry(next, forward_path, next.is_valid)
yield ResultEntry(next, forward_path, next.is_valid and satisfiable)

# Remove the visited nodes
path_idx = 0
Expand Down Expand Up @@ -282,6 +289,7 @@ class OutgoingTransition:
def __init__(self, target: Node) -> None:
self.target = target
self._num_paths: int = 0
self._satisfiable: bool = False


class IncomingTransition:
Expand Down
16 changes: 0 additions & 16 deletions fences/core/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,3 @@ def generate_random_number(min_value: Optional[int] = None, max_value: Optional[
max_value = +1000
assert min_value <= max_value
return random.randint(min_value, max_value)


def generate_random_format(format: str) -> str:
# From https://json-schema.org/understanding-json-schema/reference/string#built-in-formats
samples = {
"date-time": "2018-11-13T20:20:39+00:00",
"time": "20:20:39+00:00",
"date": "2018-11-13",
"duration": "P3D",
"email": "test@example.com",
"hostname": "example.com",
"ipv4": "127.0.0.1",
"ipv6": "2001:db8::8a2e:370:7334",
"uuid": "3e4666bf-d5e5-4aa7-b8ce-cefe41c7568a",
}
return samples.get(format, "")
12 changes: 12 additions & 0 deletions fences/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,15 @@ def render(self):

def print(self):
print_table(self.to_table())

def add(self, is_valid: bool, accepted: bool):
if is_valid:
if accepted:
self.valid_accepted += 1
else:
self.valid_rejected += 1
else:
if accepted:
self.invalid_accepted += 1
else:
self.invalid_rejected += 1
9 changes: 7 additions & 2 deletions fences/json_schema/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
from typing import Callable, Set, List, Dict

from .json_pointer import JsonPointer
from ..core.random import StringProperties

from fences.core.node import Decision

Handler = Callable[[dict, "Config", Set[str], JsonPointer], Decision]


@dataclass
class FormatSamples:
valid: List[str] = field(default_factory=list)
invalid: List[str] = field(default_factory=list)


@dataclass
class Config:
key_handlers: Dict[str, Handler]
type_handlers: Dict[str, Handler]
default_samples: Dict[str, List[any]]
normalize: bool
format_samples: Dict[str, FormatSamples]
9 changes: 6 additions & 3 deletions fences/json_schema/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def _invert_items(items: dict):
'properties': _invert_properties,
'multipleOf': lambda x: {'type': ['number'], 'NOT_multipleOf': x},
'required': lambda x: {'type': ['object'], 'properties': {i: False for i in x}},
#'required': lambda x: {'type': ['object']},
'items': _invert_items,
'minItems': lambda x: {'type': 'array', 'maxItems': x},
'maxItems': lambda x: {'type': 'array', 'minItems': x},
Expand Down Expand Up @@ -117,6 +116,9 @@ def _float_gcd(a, b, rtol = 1e-05, atol = 1e-08):
a, b = b, a % b
return a

def _ignore(_, __) -> None:
return None

_simple_mergers = {
'required': lambda a, b: list(set(a) | set(b)),
'multipleOf': lambda a, b: abs(a*b) // _float_gcd(a, b),
Expand All @@ -131,10 +133,11 @@ def _float_gcd(a, b, rtol = 1e-05, atol = 1e-08):
'maxLength': lambda a, b: min(a, b),
'enum': lambda a, b: a + b,
'format': lambda a, b: a, # todo
'deprecated': lambda a, b: a or b,
'deprecated': _ignore,
'NOT_enum': lambda a, b: a + b,
'enum': _merge_enums,
'example': lambda _, __: None
'example': _ignore,
'discriminator': _ignore,
}


Expand Down
Loading

0 comments on commit b04db39

Please sign in to comment.