Skip to content

Commit

Permalink
Merge pull request #3818 from tybug/datatree-ir
Browse files Browse the repository at this point in the history
Migrate `DataTree` to the new IR
  • Loading branch information
Zac-HD authored Feb 5, 2024
2 parents afca97e + 4d4a32f commit 30ded43
Show file tree
Hide file tree
Showing 16 changed files with 1,309 additions and 202 deletions.
5 changes: 5 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
RELEASE_TYPE: patch

This release improves our distribution of generated values for all strategies, by doing a better job of tracking which values we have generated before and avoiding generating them again.

For example, ``st.lists(st.integers())`` previously generated ``[]`` and ``[0]`` about 5 times each in 100 examples. In this release, ``[]`` and ``[0]`` are each generated only ~1-2 times.
189 changes: 149 additions & 40 deletions hypothesis-python/src/hypothesis/internal/conjecture/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
)
Expand Down Expand Up @@ -163,6 +164,8 @@ def structural_coverage(label: int) -> StructuralCoverageTag:

FLOAT_INIT_LOGIC_CACHE = LRUReusedCache(4096)

DRAW_STRING_DEFAULT_MAX_SIZE = 10**10 # "arbitrarily large"


class Example:
"""Examples track the hierarchical structure of draws from the byte stream,
Expand Down Expand Up @@ -794,6 +797,34 @@ def as_result(self) -> "_Overrun":
MAX_DEPTH = 100


class IntegerKWargs(TypedDict):
# Shape of the keyword-argument dict describing one integer draw. The
# draw_integer wrapper that forwards to ``self.provider.draw_integer``
# builds a dict of this type, unpacks it into the provider call, and hands
# the same dict to ``DataObserver.draw_integer``. ``forced`` is not a key:
# observers receive ``was_forced`` instead.
#
# Lower bound of the draw. NOTE(review): None presumably means "unbounded
# below" -- confirm against the provider.
min_value: Optional[int]
# Upper bound of the draw. NOTE(review): None presumably means "unbounded
# above" -- confirm against the provider.
max_value: Optional[int]
# Optional weights; the provider passes these to ``Sampler``.
weights: Optional[Sequence[float]]
# Origin the provider offsets from (it returns ``shrink_towards + probe``
# in the half-bounded cases).
shrink_towards: int


class FloatKWargs(TypedDict):
# Shape of the keyword-argument dict describing one float draw. It is built
# by the draw_float wrapper, unpacked into ``self.provider.draw_float``, and
# passed unchanged to ``DataObserver.draw_float``. ``forced`` is not a key.
#
# Lower bound; the wrapper asserts it is not NaN, and a non-NaN forced value
# must satisfy ``min_value <= forced <= max_value``.
min_value: float
# Upper bound; see ``min_value`` for the forced-value constraint.
max_value: float
# When False, the wrapper asserts that a forced value is not NaN.
allow_nan: bool
# Asserted to be strictly positive by the wrapper.
smallest_nonzero_magnitude: float


class StringKWargs(TypedDict):
# Shape of the keyword-argument dict describing one string draw. It is built
# by the draw_string wrapper, unpacked into ``self.provider.draw_string``,
# and passed unchanged to ``DataObserver.draw_string``. ``forced`` is not a
# key.
#
# NOTE(review): presumably the set of permitted characters -- IntervalSet is
# defined outside this view, confirm.
intervals: IntervalSet
# Minimum length; the wrapper asserts a forced string is at least this long.
min_size: int
# Maximum length. None is stored as-is here; the provider's draw_string
# substitutes DRAW_STRING_DEFAULT_MAX_SIZE when it receives None.
max_size: Optional[int]


class BytesKWargs(TypedDict):
# Shape of the keyword-argument dict describing one bytes draw. It is built
# by the draw_bytes wrapper, unpacked into ``self.provider.draw_bytes``, and
# passed unchanged to ``DataObserver.draw_bytes``. ``forced`` is not a key.
#
# Requested length; the wrapper asserts ``size >= 0`` and that a forced value
# has exactly this length.
size: int


class BooleanKWargs(TypedDict):
# Shape of the keyword-argument dict describing one boolean draw. It is built
# by the draw_boolean wrapper, unpacked into ``self.provider.draw_boolean``,
# and passed unchanged to ``DataObserver.draw_boolean``. ``forced`` is not a
# key.
#
# The wrapper's default is 0.5. NOTE(review): presumably the probability of
# drawing True -- confirm against the provider.
p: float


class DataObserver:
"""Observer class for recording the behaviour of a
ConjectureData object, primarily used for tracking
Expand All @@ -810,18 +841,34 @@ def conclude_test(
Note that this is called after ``freeze`` has completed.
"""

def draw_bits(self, n_bits: int, *, forced: bool, value: int) -> None:
"""Called when ``draw_bits`` is called on the
observed ``ConjectureData``.
* ``n_bits`` is the number of bits drawn.
* ``forced`` is ``True`` if the corresponding
draw was forced or ``False`` otherwise.
* ``value`` is the result that ``draw_bits`` returned.
"""

def kill_branch(self) -> None:
"""Mark this part of the tree as not worth re-exploring."""
# The body here is only the docstring, so this base implementation does
# nothing. NOTE(review): presumably overridden by observers that actually
# record a tree -- confirm.

def draw_integer(
self, value: int, *, was_forced: bool, kwargs: IntegerKWargs
) -> None:
# Observer hook for integer draws. The draw_integer wrapper calls this after
# the provider has returned, and only when its ``observe`` argument is True.
# * ``value`` is the integer the provider returned.
# * ``was_forced`` is True when the caller supplied ``forced`` (not None).
# * ``kwargs`` is the dict of draw parameters that was passed to the provider.
# This base implementation intentionally records nothing.
pass

def draw_float(
self, value: float, *, was_forced: bool, kwargs: FloatKWargs
) -> None:
# Observer hook for float draws. The draw_float wrapper calls this after the
# provider has returned, and only when its ``observe`` argument is True.
# * ``value`` is the float the provider returned.
# * ``was_forced`` is True when the caller supplied ``forced`` (not None).
# * ``kwargs`` is the dict of draw parameters that was passed to the provider.
# This base implementation intentionally records nothing.
pass

def draw_string(
self, value: str, *, was_forced: bool, kwargs: StringKWargs
) -> None:
# Observer hook for string draws. The draw_string wrapper calls this after
# the provider has returned, and only when its ``observe`` argument is True.
# * ``value`` is the string the provider returned.
# * ``was_forced`` is True when the caller supplied ``forced`` (not None).
# * ``kwargs`` is the dict of draw parameters that was passed to the provider.
# This base implementation intentionally records nothing.
pass

def draw_bytes(
self, value: bytes, *, was_forced: bool, kwargs: BytesKWargs
) -> None:
# Observer hook for bytes draws. The draw_bytes wrapper calls this after the
# provider has returned, and only when its ``observe`` argument is True.
# * ``value`` is the bytes object the provider returned.
# * ``was_forced`` is True when the caller supplied ``forced`` (not None).
# * ``kwargs`` is the dict of draw parameters that was passed to the provider.
# This base implementation intentionally records nothing.
pass

def draw_boolean(
self, value: bool, *, was_forced: bool, kwargs: BooleanKWargs
) -> None:
# Observer hook for boolean draws. The draw_boolean wrapper calls this after
# the provider has returned, and only when its ``observe`` argument is True.
# * ``value`` is the boolean the provider returned.
# * ``was_forced`` is True when the caller supplied ``forced`` (not None).
# * ``kwargs`` is the dict of draw parameters that was passed to the provider.
# This base implementation intentionally records nothing.
pass


@dataclass_transform()
@attr.s(slots=True)
Expand Down Expand Up @@ -995,7 +1042,7 @@ def draw_integer(
assert min_value is not None
assert max_value is not None

sampler = Sampler(weights)
sampler = Sampler(weights, observe=False)
gap = max_value - shrink_towards

forced_idx = None
Expand Down Expand Up @@ -1023,7 +1070,7 @@ def draw_integer(
probe = shrink_towards + self._draw_unbounded_integer(
forced=None if forced is None else forced - shrink_towards
)
self._cd.stop_example(discard=max_value < probe)
self._cd.stop_example()
return probe

if max_value is None:
Expand All @@ -1034,7 +1081,7 @@ def draw_integer(
probe = shrink_towards + self._draw_unbounded_integer(
forced=None if forced is None else forced - shrink_towards
)
self._cd.stop_example(discard=probe < min_value)
self._cd.stop_example()
return probe

return self._draw_bounded_integer(
Expand Down Expand Up @@ -1091,7 +1138,7 @@ def draw_float(
assert pos_clamper is not None
clamped = pos_clamper(result)
if clamped != result and not (math.isnan(result) and allow_nan):
self._cd.stop_example(discard=True)
self._cd.stop_example()
self._cd.start_example(DRAW_FLOAT_LABEL)
self._draw_float(forced=clamped)
result = clamped
Expand All @@ -1113,7 +1160,7 @@ def draw_string(
forced: Optional[str] = None,
) -> str:
if max_size is None:
max_size = 10**10 # "arbitrarily large"
max_size = DRAW_STRING_DEFAULT_MAX_SIZE

assert forced is None or min_size <= len(forced) <= max_size

Expand All @@ -1129,6 +1176,7 @@ def draw_string(
max_size=max_size,
average_size=average_size,
forced=None if forced is None else len(forced),
observe=False,
)
while elements.more():
forced_i: Optional[int] = None
Expand Down Expand Up @@ -1264,7 +1312,7 @@ def _draw_bounded_integer(
probe = self._cd.draw_bits(
bits, forced=None if forced is None else abs(forced - center)
)
self._cd.stop_example(discard=probe > gap)
self._cd.stop_example()

if above:
result = center + probe
Expand Down Expand Up @@ -1356,7 +1404,7 @@ def permitted(f):
]
nasty_floats = [f for f in NASTY_FLOATS + boundary_values if permitted(f)]
weights = [0.2 * len(nasty_floats)] + [0.8] * len(nasty_floats)
sampler = Sampler(weights) if nasty_floats else None
sampler = Sampler(weights, observe=False) if nasty_floats else None

pos_clamper = neg_clamper = None
if sign_aware_lte(0.0, max_value):
Expand Down Expand Up @@ -1465,6 +1513,19 @@ def __repr__(self):
", frozen" if self.frozen else "",
)

# A bit of explanation of the `observe` argument in our draw_* functions.
#
# There are two types of draws: sub-ir and super-ir. For instance, some ir
# nodes use `many`, which in turn calls draw_boolean. But some strategies
# also use many, at the super-ir level. We don't want to write sub-ir draws
# to the DataTree (and consequently use them when computing novel prefixes),
# since they are fully recorded by writing the ir node itself.
# But super-ir draws are not included in the ir node, so we do want to write
# these to the tree.
#
# `observe` formalizes this distinction. The draw will only be written to
# the DataTree if observe is True.

def draw_integer(
self,
min_value: Optional[int] = None,
Expand All @@ -1474,6 +1535,7 @@ def draw_integer(
weights: Optional[Sequence[float]] = None,
shrink_towards: int = 0,
forced: Optional[int] = None,
observe: bool = True,
) -> int:
# Validate arguments
if weights is not None:
Expand All @@ -1494,13 +1556,18 @@ def draw_integer(
if forced is not None and max_value is not None:
assert forced <= max_value

return self.provider.draw_integer(
min_value=min_value,
max_value=max_value,
weights=weights,
shrink_towards=shrink_towards,
forced=forced,
)
kwargs: IntegerKWargs = {
"min_value": min_value,
"max_value": max_value,
"weights": weights,
"shrink_towards": shrink_towards,
}
value = self.provider.draw_integer(**kwargs, forced=forced)
if observe:
self.observer.draw_integer(
value, was_forced=forced is not None, kwargs=kwargs
)
return value

def draw_float(
self,
Expand All @@ -1514,6 +1581,7 @@ def draw_float(
# width: Literal[16, 32, 64] = 64,
# exclude_min and exclude_max handled higher up,
forced: Optional[float] = None,
observe: bool = True,
) -> float:
assert smallest_nonzero_magnitude > 0
assert not math.isnan(min_value)
Expand All @@ -1523,13 +1591,18 @@ def draw_float(
assert allow_nan or not math.isnan(forced)
assert math.isnan(forced) or min_value <= forced <= max_value

return self.provider.draw_float(
min_value=min_value,
max_value=max_value,
allow_nan=allow_nan,
smallest_nonzero_magnitude=smallest_nonzero_magnitude,
forced=forced,
)
kwargs: FloatKWargs = {
"min_value": min_value,
"max_value": max_value,
"allow_nan": allow_nan,
"smallest_nonzero_magnitude": smallest_nonzero_magnitude,
}
value = self.provider.draw_float(**kwargs, forced=forced)
if observe:
self.observer.draw_float(
value, kwargs=kwargs, was_forced=forced is not None
)
return value

def draw_string(
self,
Expand All @@ -1538,19 +1611,44 @@ def draw_string(
min_size: int = 0,
max_size: Optional[int] = None,
forced: Optional[str] = None,
observe: bool = True,
) -> str:
assert forced is None or min_size <= len(forced)
return self.provider.draw_string(
intervals, min_size=min_size, max_size=max_size, forced=forced
)

def draw_bytes(self, size: int, *, forced: Optional[bytes] = None) -> bytes:
kwargs: StringKWargs = {
"intervals": intervals,
"min_size": min_size,
"max_size": max_size,
}
value = self.provider.draw_string(**kwargs, forced=forced)
if observe:
self.observer.draw_string(
value, kwargs=kwargs, was_forced=forced is not None
)
return value

def draw_bytes(
self,
# TODO move to min_size and max_size here.
size: int,
*,
forced: Optional[bytes] = None,
observe: bool = True,
) -> bytes:
assert forced is None or len(forced) == size
assert size >= 0

return self.provider.draw_bytes(size, forced=forced)
kwargs: BytesKWargs = {"size": size}
value = self.provider.draw_bytes(**kwargs, forced=forced)
if observe:
self.observer.draw_bytes(
value, kwargs=kwargs, was_forced=forced is not None
)
return value

def draw_boolean(self, p: float = 0.5, *, forced: Optional[bool] = None) -> bool:
def draw_boolean(
self, p: float = 0.5, *, forced: Optional[bool] = None, observe: bool = True
) -> bool:
# Internally, we treat probabilities lower than 1 / 2**64 as
# unconditionally false.
#
Expand All @@ -1561,7 +1659,13 @@ def draw_boolean(self, p: float = 0.5, *, forced: Optional[bool] = None) -> bool
if forced is False:
assert p < (1 - 2 ** (-64))

return self.provider.draw_boolean(p, forced=forced)
kwargs: BooleanKWargs = {"p": p}
value = self.provider.draw_boolean(**kwargs, forced=forced)
if observe:
self.observer.draw_boolean(
value, kwargs=kwargs, was_forced=forced is not None
)
return value

def as_result(self) -> Union[ConjectureResult, _Overrun]:
"""Convert the result of running this test into
Expand Down Expand Up @@ -1735,9 +1839,15 @@ def freeze(self) -> None:
self.buffer = bytes(self.buffer)
self.observer.conclude_test(self.status, self.interesting_origin)

def choice(self, values: Sequence[T], *, forced: Optional[T] = None) -> T:
def choice(
self,
values: Sequence[T],
*,
forced: Optional[T] = None,
observe: bool = True,
) -> T:
forced_i = None if forced is None else values.index(forced)
i = self.draw_integer(0, len(values) - 1, forced=forced_i)
i = self.draw_integer(0, len(values) - 1, forced=forced_i, observe=observe)
return values[i]

def draw_bits(self, n: int, *, forced: Optional[int] = None) -> int:
Expand Down Expand Up @@ -1774,7 +1884,6 @@ def draw_bits(self, n: int, *, forced: Optional[int] = None) -> int:
buf = bytes(buf)
result = int_from_bytes(buf)

self.observer.draw_bits(n, forced=forced is not None, value=result)
self.__example_record.draw_bits(n, forced)

initial = self.index
Expand Down
Loading

0 comments on commit 30ded43

Please sign in to comment.