Skip to content

Commit

Permalink
WIP: LazyDeltaDiff
Browse files Browse the repository at this point in the history
This is a refactoring of #1032; the aim is to make the laziness of the
DeltaDiff class more explicit.

It doesn't quite work yet; a handful of tests still fail.
  • Loading branch information
craigds committed Jan 9, 2025
1 parent 8c8dc4f commit 0b9b610
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 101 deletions.
4 changes: 2 additions & 2 deletions kart/base_diff_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,8 @@ def get_file_diff(self):

def iter_deltadiff_items(self, deltas):
if self.sort_keys:
return deltas.sorted_items()
return deltas.iter_items()
return deltas.resolve().sorted_items()
return deltas.items()

def filtered_dataset_deltas(self, ds_path, ds_diff):
"""
Expand Down
165 changes: 72 additions & 93 deletions kart/diff_structs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import UserDict
from dataclasses import dataclass
from typing import Any, Iterator
from itertools import chain
from typing import Any, Iterable, Iterator

from .exceptions import InvalidOperation

Expand Down Expand Up @@ -359,17 +360,21 @@ class Diff(RichDict):
"""

@classmethod
def concatenated(cls, *diffs, overwrite_original=False):
def concatenated(cls, *diffs):
"""
Concatenate a list of diffs, returning a new diff.
Note: This may consume/modify the diffs that were passed in for performance reasons;
it's not safe to use them after this method returns.
"""
result = None
for diff in diffs:
if diff is None:
continue
elif result is None:
result = diff
elif overwrite_original:
result += diff
else:
result = result + diff
result += diff
return result if result is not None else cls()

def __invert__(self):
Expand All @@ -383,6 +388,7 @@ def __add__(self, other: "Diff"):

# FIXME: this algorithm isn't perfect when renames are involved.

other = other.resolve()
if type(self) != type(other):
raise TypeError(f"Diff type mismatch: {type(self)} != {type(other)}")

Expand Down Expand Up @@ -432,117 +438,90 @@ def type_counts(self):
def __json__(self):
return {k: v for k, v in self.items()}


class InvalidatedDeltaDiff(Exception):
pass
def resolve(self):
"""
Returns a Diff instance (noop, for API compatibility with LazyDeltaDiff).
"""
return self


class DeltaDiff(Diff):
class LazyDeltaDiff:
"""
A DeltaDiff is the inner-most type of Diff, the one that actually contains Deltas.
Since Deltas know the keys at which they should be stored, a DeltaDiff makes sure to store Deltas at these keys.
=== Using DeltaDiff with an iterator of Deltas ===
It is possible to pass in an iterator of Deltas (e.g. a generator) to the DeltaDiff constructor,
in which case the Deltas are NOT immediately stored in the DeltaDiff.
A LazyDeltaDiff is like a DeltaDiff containing an iterator of Deltas, which is lazily evaluated.
This is useful because there may be a lot of Deltas, and we don't want to store them in memory.
The only correct way to consume a DeltaDiff populated by a generator is to call `iter_items()`,
The only correct way to consume a LazyDeltaDiff populated by a generator is to call `items()`,
which will consume the iterator as it yields Deltas.
Calling that method will invalidate the DeltaDiff, so it cannot be used again (doing so will throw an exception)
Calling that method will invalidate the LazyDeltaDiff, so it cannot be used again (doing so will throw an exception)
Calling any other dict-like method (keys(), items(), len() etc) will consume the iterator and store its
contents in memory, which may be expensive.
To consume the iterator into memory and turn the LazyDeltaDiff into a DeltaDiff, call `resolve()`
"""

child_type = Delta
_wrapped_iter: Iterator[Delta]

def __init__(self, initial_contents=()):
# An iterator over keys and Delta objects, which is consumed lazily
self._lazy_initial_contents: Iterator[tuple[str, Delta]] | None = None
if isinstance(initial_contents, (dict, UserDict)):
super().__init__(initial_contents)
def __init__(self, initial_contents: Iterable[Delta] = ()):
wrapped_iter = iter(initial_contents)
try:
first_item = next(wrapped_iter)
except StopIteration:
self._wrapped_iter = iter(())
self._bool = False
else:
if isinstance(initial_contents, Iterator):
self._lazy_initial_contents = (
(delta.key, delta) for delta in initial_contents
)
initial_contents = ()
super().__init__((delta.key, delta) for delta in initial_contents)

def __getitem__(self, key):
if key in self.data:
return self.data[key]
self._evaluate_lazy_initial_contents()
return self.data[key]

def __setitem__(self, key, delta):
if key != delta.key:
raise ValueError("Delta must be added at the appropriate key")
super().__setitem__(key, delta)

def _evaluate_lazy_initial_contents(self):
if self._lazy_initial_contents is None:
return
# Consume the generator to populate the DeltaDiff.
self.update(self._lazy_initial_contents)
self._lazy_initial_contents = None
self._wrapped_iter = chain((first_item,), wrapped_iter)
self._bool = True
self._consumed = False

def __bool__(self):
result = bool(self.data)
if (not result) and self._lazy_initial_contents:
# If the DeltaDiff is empty, but has lazy initial contents, evaluate the first item to check booleanness.
try:
k, v = next(self._lazy_initial_contents)
except StopIteration:
return False
else:
# remember this result
self.data[k] = v
return True
return result
return self._bool

def __len__(self):
self._evaluate_lazy_initial_contents()
return super().__len__()
def __add__(self, other):
resolved = self.resolve()
resolved += other
return resolved

def items(self):
self._evaluate_lazy_initial_contents()
return super().items()
def _check_not_consumed(self):
if self._consumed:
raise ValueError("LazyDeltaDiff has already been consumed")

def iter_items(self):
def items(self) -> Iterator[tuple[str, Delta]]:
"""
Iterates over the items in the DeltaDiff, including any lazy initial contents.
Iterates over the items in the LazyDeltaDiff.
This method consumes the iterator without storing its contents.
It's not safe to call this method and then consume the DeltaDiff again.
"""
yield from self.data.items()
if self._lazy_initial_contents:
for k, v in self._lazy_initial_contents:
yield (k, v)

# Invalidate this DeltaDiff; it's not safe to consume it again after this.
# `data` is the underlying contents of UserDict, which we inherit from.
# So overriding it to a non-dict will cause all dict methods to raise exceptions.
# > TypeError: argument of type 'InvalidatedDeltaDiff' is not iterable
self.data = InvalidatedDeltaDiff(
"DeltaDiff can't be used after iter_items() has been called"
)

def keys(self):
"""
Overrides the dict.keys() method to ensure we consume any lazy initial contents first
"""
self._evaluate_lazy_initial_contents()
return super().keys()
self._check_not_consumed()
self._consumed = True
for delta in self._wrapped_iter:
yield (delta.key, delta)

def values(self):
def resolve(self):
"""
Overrides the dict.values() method to ensure we consume any lazy initial contents first
Converts the LazyDeltaDiff into a DeltaDiff by consuming the wrapped iterator.
"""
self._evaluate_lazy_initial_contents()
return super().values()
self._check_not_consumed()
self._consumed = True
return DeltaDiff(self._wrapped_iter)


class DeltaDiff(Diff):
"""
A DeltaDiff is the inner-most type of Diff, the one that actually contains Deltas.
Since Deltas know the keys at which they should be stored, a DeltaDiff makes sure to store Deltas at these keys.
"""

child_type = Delta

def __init__(self, initial_contents=()):
if isinstance(initial_contents, (dict, UserDict)):
super().__init__(initial_contents)
else:
super().__init__((delta.key, delta) for delta in initial_contents)

def __setitem__(self, key, delta):
if key != delta.key:
raise ValueError("Delta must be added at the appropriate key")
super().__setitem__(key, delta)

def add_delta(self, delta):
"""Add the given delta at the appropriate key."""
Expand Down Expand Up @@ -617,7 +596,7 @@ def recursive_len(self, max_depth=None):
class DatasetDiff(Diff):
"""A DatasetDiff contains up to two DeltaDiffs, at keys "meta" or "feature"."""

child_type = (DeltaDiff, bool)
child_type = (LazyDeltaDiff, DeltaDiff, bool)

def __json__(self):
result = {}
Expand Down
4 changes: 1 addition & 3 deletions kart/diff_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ def get_dataset_diff(
ds_path,
repr(target_wc_diff),
)
ds_diff = DatasetDiff.concatenated(
base_target_diff, target_wc_diff, overwrite_original=True
)
ds_diff = DatasetDiff.concatenated(base_target_diff, target_wc_diff)
if include_wc_diff:
# Get rid of parts of the diff-structure that are "empty":
ds_diff.prune()
Expand Down
1 change: 1 addition & 0 deletions kart/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def get_branch_status_message(repo):

def get_diff_status_message(diff):
"""Given a diff.Diff, return a status message describing it."""
diff = diff.resolve()
return diff_status_to_text(diff.type_counts())


Expand Down
6 changes: 4 additions & 2 deletions kart/tabular/rich_table_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from osgeo import osr

from kart import crs_util
from kart.diff_structs import Delta, DeltaDiff, DatasetDiff
from kart.diff_structs import Delta, DeltaDiff, DatasetDiff, LazyDeltaDiff
from kart.exceptions import PATCH_DOES_NOT_APPLY, InvalidOperation, NotYetImplemented
from kart.key_filters import DatasetKeyFilter, FeatureKeyFilter
from kart.promisor_utils import fetch_promised_blobs, object_is_promised
Expand Down Expand Up @@ -136,7 +136,9 @@ def diff(
else:
ds_diff.set_if_nonempty(
"feature",
DeltaDiff(self.diff_feature(other, feature_filter, reverse=reverse)),
LazyDeltaDiff(
self.diff_feature(other, feature_filter, reverse=reverse)
),
)
return ds_diff

Expand Down
1 change: 0 additions & 1 deletion kart/workdir.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,6 @@ def _diff_to_reset(
target_datasets,
ds_filter=ds_filter,
),
overwrite_original=True,
)

tile_diff = ds_diff.get("tile")
Expand Down

0 comments on commit 0b9b610

Please sign in to comment.