diff --git a/docs/source/user_guide/using_views.rst b/docs/source/user_guide/using_views.rst index d141ba94247..65f9e78ccdb 100644 --- a/docs/source/user_guide/using_views.rst +++ b/docs/source/user_guide/using_views.rst @@ -1006,6 +1006,143 @@ contains duplicate sample IDs: The :ref:`FiftyOne App ` is not designed to display views with duplicate sample IDs. +.. _materializing-views: + +Materializing views +___________________ + +You can use +:meth:`materialize() ` +to cache any view as a temporary database collection. + +Materialized views are identical in content and function to their source views, +but since their contents have been precomputed and cached in the database, +computing/iterating over them is as fast as possible. + +Applying this stage to an expensive view like an unindexed filtering/sorting +operation on a large dataset can be helpful if you plan to perform multiple +downstream operations on the view. + +.. code-block:: python + :linenos: + + import fiftyone as fo + import fiftyone.zoo as foz + from fiftyone import ViewField as F + + dataset = foz.load_zoo_dataset("quickstart") + + # A complex view that involves filtering/sorting + view = ( + dataset + .filter_labels("ground_truth", F("label") == "person") + .sort_by(F("ground_truth.detections.label").length(), reverse=True) + .limit(10) + ) + + # Materialize the view + materialized_view = view.materialize() + + # The materialized view's contents are identical to the input view + assert len(view) == len(materialized_view) + + # but it doesn't require aggregation to serve its contents + assert view._pipeline() != [] + assert materialized_view._pipeline() == [] + + # so downstream tasks like statistics are more efficient + print(view.count_values("ground_truth.detections.label")) + print(materialized_view.count_values("ground_truth.detections.label")) + # {'person': 140} + # {'person': 140} + +.. note:: + + Like :ref:`non-persistent datasets `, the temporary + collections created by materializing views are automatically deleted when + all active FiftyOne sessions are closed. + +Materialized views are fully compliant with the |DatasetView| interface, so you +are free to chain additional view stages as usual: + +.. code-block:: python + :linenos: + + complex_view = materialized_view.to_patches("ground_truth").limit(50) + print(complex_view.count_values("ground_truth.label")) + # {'person': 50} + +You can even store them as :ref:`saved views ` on your dataset +so you can load them later: + +.. code-block:: python + :linenos: + + dataset.save_view("most_people", materialized_view) + + # later in another session... + + same_materialized_view = dataset.load_saved_view("most_people") + + print(same_materialized_view.count_values("ground_truth.detections.label")) + # {'person': 140} + +.. note:: + + If a saved materialized view's cached data has been deleted, the + collection(s) will be automatically regenerated from the underlying + dataset when the saved view is loaded. + +Note that, since materialized views are served from a cache, their contents +are not immediately affected if the source dataset is edited. To pull in +changes, you can call +:meth:`reload() ` to +rehydrate an existing view, or you can use +:meth:`materialize() ` +to create a new materialized view: + +.. 
code-block:: python + :linenos: + + # Delete the 10 samples with the most people from the dataset + dataset.delete_samples(view[:10]) + + # Existing materialized view does not yet reflect the changes + print(view.count_values("ground_truth.detections.label")) + print(materialized_view.count_values("ground_truth.detections.label")) + # {'person': 101} + # {'person': 140} + + # Option 1: reload an existing materialized view + materialized_view.reload() + + # Option 2: create a new materialized view + also_materialized_view = view.materialize() + + print(materialized_view.count_values("ground_truth.detections.label")) + print(also_materialized_view.count_values("ground_truth.detections.label")) + # {'person': 101} + # {'person': 101} + +Materialized views also behave just like any other view in the sense that: + +- Any modifications to the samples that you make by iterating over the + contents of the materialized view or calling methods like + :meth:`set_values() ` + will be reflected on the source dataset +- Any modifications to sample or label tags that you make via the App's + :ref:`tagging menu ` or via API methods like + :meth:`tag_samples() ` + and + :meth:`tag_labels() ` + will be reflected on the source dataset +- Any edits to the underlying dataset's contents that you make by calling + other methods like + :meth:`save() `, + :meth:`keep() `, and + :meth:`keep_fields() ` + on the materialized view will automatically be reflected on the source + dataset + .. _date-views: Date-based views diff --git a/fiftyone/__public__.py b/fiftyone/__public__.py index 1ccd9d15a51..316b88c6962 100644 --- a/fiftyone/__public__.py +++ b/fiftyone/__public__.py @@ -214,6 +214,7 @@ MatchFrames, MatchLabels, MatchTags, + Materialize, Mongo, Select, SelectBy, diff --git a/fiftyone/core/clips.py b/fiftyone/core/clips.py index b05cdb4c25b..1990be274bf 100644 --- a/fiftyone/core/clips.py +++ b/fiftyone/core/clips.py @@ -9,6 +9,7 @@ from copy import deepcopy from bson import ObjectId +from pymongo import UpdateOne, UpdateMany import eta.core.utils as etau @@ -47,10 +48,13 @@ def _sample_id(self): return ObjectId(self._doc.sample_id) def _save(self, deferred=False): - sample_ops, frame_ops = super()._save(deferred=deferred) + sample_ops, frame_ops = super()._save(deferred=True) if not deferred: - self._view._sync_source_sample(self) + self._view._save_sample( + self, sample_ops=sample_ops, frame_ops=frame_ops + ) + return None, [] return sample_ops, frame_ops @@ -334,14 +338,73 @@ def reload(self): super().reload() - def _sync_source_sample(self, sample): - if not self._classification_field: + def _check_for_field_edits(self, ops, fields): + updated_fields = set() + + for op in ops: + if isinstance(op, (UpdateOne, UpdateMany)): + updated_fields.update(op._doc.get("$set", {}).keys()) + updated_fields.update(op._doc.get("$unset", {}).keys()) + + for field in list(updated_fields): + chunks = field.split(".") + for i in range(1, len(chunks)): + updated_fields.add(".".join(chunks[:i])) + + return bool(updated_fields & set(fields)) + + def _bulk_write( + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, + ): + self._clips_dataset._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) + + # Clips views directly use their source collection's frames, so there's + # no need to sync + if frames: return - # Sync label + support to underlying TemporalDetection + field = self._classification_field + if field is not None and 
self._check_for_field_edits(ops, [field]): + self._sync_source(fields=[field], ids=ids) + self._source_collection._dataset._reload_docs(ids=ids) + + def _save_sample(self, sample, sample_ops=None, frame_ops=None): + if sample_ops: + foo.bulk_write(sample_ops, self._clips_dataset._sample_collection) + + if frame_ops: + foo.bulk_write(frame_ops, self._clips_dataset._frame_collection) + self._sync_source_sample( + sample, sample_ops=sample_ops, frame_ops=frame_ops + ) + + def _sync_source_sample(self, sample, sample_ops=None, frame_ops=None): field = self._classification_field + if not field: + return + + if sample_ops is not None and not self._check_for_field_edits( + sample_ops, [field] + ): + return + + # Sync label + support to underlying TemporalDetection + classification = sample[field] if classification is not None: doc = classification.to_dict() @@ -353,10 +416,9 @@ def _sync_source_sample(self, sample): self._source_collection._set_labels(field, [sample.sample_id], [doc]) def _sync_source(self, fields=None, ids=None, update=True, delete=False): - if not self._classification_field: - return - field = self._classification_field + if not field: + return if fields is not None and field not in fields: return diff --git a/fiftyone/core/collections.py b/fiftyone/core/collections.py index 1cbc54d1274..fb3d12d0a97 100644 --- a/fiftyone/core/collections.py +++ b/fiftyone/core/collections.py @@ -113,14 +113,9 @@ def __init__( self.batch_size = batch_size self._dataset = sample_collection._dataset - self._sample_coll = sample_collection._dataset._sample_collection - self._frame_coll = sample_collection._dataset._frame_collection - self._is_generated = sample_collection._is_generated - self._sample_ops = [] self._frame_ops = [] self._batch_ids = [] - self._reload_parents = [] self._batching_strategy = batching_strategy self._curr_batch_size = None @@ -154,7 +149,6 @@ def save(self, sample): ) sample_ops, frame_ops = sample._save(deferred=True) - updated = sample_ops or frame_ops if sample_ops: self._sample_ops.extend(sample_ops) @@ -162,12 +156,9 @@ def save(self, sample): if frame_ops: self._frame_ops.extend(frame_ops) - if updated and self._is_generated: + if sample_ops or frame_ops: self._batch_ids.append(sample.id) - if updated and isinstance(sample, fosa.SampleView): - self._reload_parents.append(sample) - if self._batching_strategy == "static": self._curr_batch_size += 1 if self._curr_batch_size >= self.batch_size: @@ -194,22 +185,23 @@ def save(self, sample): def _save_batch(self): if self._sample_ops: - foo.bulk_write(self._sample_ops, self._sample_coll, ordered=False) + self.sample_collection._bulk_write( + self._sample_ops, + ids=self._batch_ids, + ordered=False, + ) self._sample_ops.clear() if self._frame_ops: - foo.bulk_write(self._frame_ops, self._frame_coll, ordered=False) + self.sample_collection._bulk_write( + self._frame_ops, + sample_ids=self._batch_ids, + frames=True, + ordered=False, + ) self._frame_ops.clear() - if self._batch_ids and self._is_generated: - self.sample_collection._sync_source(ids=self._batch_ids) - self._batch_ids.clear() - - if self._reload_parents: - for sample in self._reload_parents: - sample._reload_parents() - - self._reload_parents.clear() + self._batch_ids.clear() class SampleCollection(object): @@ -291,6 +283,11 @@ def _is_clips(self): """Whether this collection contains clips.""" raise NotImplementedError("Subclass must implement _is_clips") + @property + def _is_materialized(self): + """Whether this collection contains a materialized view.""" + 
raise NotImplementedError("Subclass must implement _is_materialized") + + @property def _is_dynamic_groups(self): """Whether this collection contains dynamic groups.""" @@ -1899,10 +1896,15 @@ def untag_samples(self, tags): view = self.match_tags(tags) view._edit_sample_tags(update) - def _edit_sample_tags(self, update): + def _edit_sample_tags(self, update, ids=None): if self._is_read_only_field("tags"): raise ValueError("Cannot edit read-only field 'tags'") + if ids is None: + _ids = self.values("_id") + else: + _ids = [ObjectId(_id) for _id in ids] + update["$set"] = {"last_modified_at": datetime.utcnow()} ids = [] @@ -1910,13 +1912,15 @@ batch_size = fou.recommend_batch_size_for_value( ObjectId(), max_size=100000 ) - for _ids in fou.iter_batches(self.values("_id"), batch_size): - ids.extend(_ids) - ops.append(UpdateMany({"_id": {"$in": _ids}}, update)) + for _batch_ids in fou.iter_batches(_ids, batch_size): + ids.extend(_batch_ids) + ops.append(UpdateMany({"_id": {"$in": _batch_ids}}, update)) if ops: self._dataset._bulk_write(ops, ids=ids) + return ids + def count_sample_tags(self): """Counts the occurrences of sample tags in this collection. @@ -2049,8 +2053,8 @@ def _edit_label_tags( if ids is None or label_ids is None: if is_frame_field: ids, label_ids = self.values(["frames._id", id_path]) - ids = itertools.chain.from_iterable(ids) - label_ids = itertools.chain.from_iterable(label_ids) + ids = list(itertools.chain.from_iterable(ids)) + label_ids = list(itertools.chain.from_iterable(label_ids)) else: ids, label_ids = self.values(["_id", id_path]) @@ -3212,6 +3216,24 @@ def _set_labels(self, field_name, sample_ids, label_docs, progress=False): def _delete_labels(self, ids, fields=None): self._dataset.delete_labels(ids=ids, fields=fields) + def _bulk_write( + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, + ): + self._dataset._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) + def compute_metadata( self, overwrite=False, @@ -6193,6 +6215,33 @@ def match_tags(self, tags, bool=None, all=False): """ return self._add_view_stage(fos.MatchTags(tags, bool=bool, all=all)) + @view_stage + def materialize(self): + """Materializes the current view into a temporary database collection. + + Apply this stage to an expensive view (e.g., an unindexed filtering + operation on a large dataset) if you plan to perform multiple + downstream operations on the view. + + Examples:: + + import fiftyone as fo + import fiftyone.zoo as foz + from fiftyone import ViewField as F + + dataset = foz.load_zoo_dataset("quickstart") + + view = dataset.filter_labels("ground_truth", F("label") == "cat") + materialized_view = view.materialize() + + print(view.count("ground_truth.detections")) + print(materialized_view.count("ground_truth.detections")) + + Returns: + a :class:`fiftyone.core.view.DatasetView` + """ + return self._add_view_stage(fos.Materialize()) + @view_stage def mongo(self, pipeline, _needs_frames=None, _group_slices=None): """Adds a view stage defined by a raw MongoDB aggregation pipeline. 
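The ``_check_for_field_edits()`` helper introduced in the clips changes above (and duplicated in the patches changes later in this diff) decides whether a batch of PyMongo update operations touches any of the fields that must be synced back to the source collection. Note that it expands each dotted path into all of its prefixes, so an edit to ``ground_truth.detections.label`` also counts as an edit to ``ground_truth``. Here is a minimal standalone sketch of that logic (written as a hypothetical free function for illustration; like the real methods, it reads PyMongo's private ``_doc`` attribute):

.. code-block:: python

    from pymongo import UpdateMany, UpdateOne

    def check_for_field_edits(ops, fields):
        # Collect every field path touched by a $set or $unset
        updated_fields = set()
        for op in ops:
            if isinstance(op, (UpdateOne, UpdateMany)):
                updated_fields.update(op._doc.get("$set", {}).keys())
                updated_fields.update(op._doc.get("$unset", {}).keys())

        # Expand each dotted path into all of its prefixes so that edits
        # to embedded fields also count as edits to their parents
        for field in list(updated_fields):
            chunks = field.split(".")
            for i in range(1, len(chunks)):
                updated_fields.add(".".join(chunks[:i]))

        return bool(updated_fields & set(fields))

    ops = [UpdateOne({}, {"$set": {"ground_truth.detections.label": "cat"}})]
    assert check_for_field_edits(ops, ["ground_truth"])
    assert not check_for_field_edits(ops, ["predictions"])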
diff --git a/fiftyone/core/dataset.py b/fiftyone/core/dataset.py index e69d17ef3a6..e7ab8b63133 100644 --- a/fiftyone/core/dataset.py +++ b/fiftyone/core/dataset.py @@ -426,7 +426,12 @@ def _root_dataset(self): @property def _is_generated(self): - return self._is_patches or self._is_frames or self._is_clips + return ( + self._is_patches + or self._is_frames + or self._is_clips + or self._is_materialized + ) @property def _is_patches(self): @@ -442,6 +447,10 @@ def _is_frames(self): def _is_clips(self): return self._sample_collection_name.startswith("clips.") + @property + def _is_materialized(self): + return self._sample_collection_name.startswith("materialized.") + @property def _is_dynamic_groups(self): return False @@ -3456,7 +3465,13 @@ def _make_dict( return d def _bulk_write( - self, ops, ids=None, frames=False, ordered=False, progress=False + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, ): if frames: coll = self._frame_collection @@ -3466,7 +3481,11 @@ def _bulk_write( foo.bulk_write(ops, coll, ordered=ordered, progress=progress) if frames: - fofr.Frame._reload_docs(self._frame_collection_name, frame_ids=ids) + fofr.Frame._reload_docs( + self._frame_collection_name, + sample_ids=sample_ids, + frame_ids=ids, + ) else: fos.Sample._reload_docs( self._sample_collection_name, sample_ids=ids @@ -4877,7 +4896,13 @@ def clone(self, name=None, persistent=False): """ return self._clone(name=name, persistent=persistent) - def _clone(self, name=None, persistent=False, view=None): + def _clone( + self, + name=None, + persistent=False, + view=None, + materialized=False, + ): if name is None: name = get_default_dataset_name() @@ -4886,7 +4911,12 @@ def _clone(self, name=None, persistent=False, view=None): else: sample_collection = self - return _clone_collection(sample_collection, name, persistent) + return _clone_collection( + sample_collection, + name, + persistent=persistent, + materialized=materialized, + ) def clear(self): """Removes all samples from the dataset. @@ -5056,7 +5086,7 @@ def _clear_frames(self, view=None, sample_ids=None, frame_ids=None): self._frame_collection_name, sample_ids=sample_ids ) - def _keep_frames(self, view=None, frame_ids=None): + def _keep_frames(self, view=None): sample_collection = view if view is not None else self if not sample_collection._contains_videos(any_slice=True): return @@ -7981,6 +8011,7 @@ def reload(self): """Reloads the dataset and any in-memory samples from the database.""" self._reload(hard=True) self._reload_docs(hard=True) + self._reload_docs(frames=True, hard=True) def clear_cache(self): """Clears the dataset's in-memory cache. 
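The new ``_is_materialized`` property above follows the same convention as the existing ``_is_patches``/``_is_frames``/``_is_clips`` checks: generated datasets are identified purely by the prefix of their sample collection name, which ``_make_sample_collection_name()`` (updated below) selects. A rough sketch of the convention; the exact suffix format is an assumption here, since only the prefix matters for these checks:

.. code-block:: python

    def make_sample_collection_name(dataset_id, materialized=False):
        # Regular datasets use "samples.<id>"; generated datasets use
        # "patches.", "frames.", "clips.", or now "materialized."
        prefix = "materialized" if materialized else "samples"
        return "%s.%s" % (prefix, dataset_id)

    name = make_sample_collection_name("662f0c5b2f3a", materialized=True)
    assert name.startswith("materialized.")  # what _is_materialized tests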
@@ -8022,11 +8053,18 @@ def _reload(self, hard=False): self._update_last_loaded_at() - def _reload_docs(self, hard=False): - fos.Sample._reload_docs(self._sample_collection_name, hard=hard) + def _reload_docs(self, ids=None, frames=False, hard=False): + if frames: + if not self._has_frame_fields(): + return - if self._has_frame_fields(): - fofr.Frame._reload_docs(self._frame_collection_name, hard=hard) + fofr.Frame._reload_docs( + self._frame_collection_name, frame_ids=ids, hard=hard + ) + else: + fos.Sample._reload_docs( + self._sample_collection_name, sample_ids=ids, hard=hard + ) def _serialize(self): return self._doc.to_dict(extended=True) @@ -8271,7 +8309,7 @@ def _clone_collection_indexes( def _make_sample_collection_name( - dataset_id, patches=False, frames=False, clips=False + dataset_id, patches=False, frames=False, clips=False, materialized=False ): if patches and frames: prefix = "patches.frames" @@ -8281,6 +8319,8 @@ def _make_sample_collection_name( prefix = "frames" elif clips: prefix = "clips" + elif materialized: + prefix = "materialized" else: prefix = "samples" @@ -8463,7 +8503,12 @@ def _delete_dataset_doc(dataset_doc): dataset_doc.delete() -def _clone_collection(sample_collection, name, persistent): +def _clone_collection( + sample_collection, + name, + persistent=False, + materialized=False, +): slug = _validate_dataset_name(name) contains_videos = sample_collection._contains_videos(any_slice=True) @@ -8490,7 +8535,9 @@ def _clone_collection(sample_collection, name, persistent): _id = dataset_doc.id now = datetime.utcnow() - sample_collection_name = _make_sample_collection_name(_id) + sample_collection_name = _make_sample_collection_name( + _id, materialized=materialized + ) if contains_videos: frame_collection_name = _make_frame_collection_name( diff --git a/fiftyone/core/materialize.py b/fiftyone/core/materialize.py new file mode 100644 index 00000000000..5af477c3f0d --- /dev/null +++ b/fiftyone/core/materialize.py @@ -0,0 +1,607 @@ +""" +Materialized views. + +| Copyright 2017-2024, Voxel51, Inc. +| `voxel51.com `_ +| +""" +from copy import deepcopy + +from bson import ObjectId + +import eta.core.utils as etau + +import fiftyone.core.media as fom +import fiftyone.core.sample as fos +import fiftyone.core.odm as foo +import fiftyone.core.utils as fou +import fiftyone.core.view as fov + + +class MaterializedSampleView(fos.SampleView): + """A sample in a :class:`MaterializedView`. + + :class:`MaterializedSampleView` instances should not be created manually; + they are generated by iterating over :class:`MaterializedView` instances. + + Args: + doc: a :class:`fiftyone.core.odm.DatasetSampleDocument` + view: the :class:`MaterializedView` that the sample belongs to + selected_fields (None): a set of field names that this view is + restricted to + excluded_fields (None): a set of field names that are excluded from + this view + filtered_fields (None): a set of field names of list fields that are + filtered in this view + """ + + def _save(self, deferred=False): + sample_ops, frame_ops = super()._save(deferred=True) + + if not deferred: + self._view._save_sample( + self, sample_ops=sample_ops, frame_ops=frame_ops + ) + return None, [] + + return sample_ops, frame_ops + + +class MaterializedView(fov.DatasetView): + """A :class:`fiftyone.core.view.DatasetView` of samples from a materialized + view. + + Samples retrieved from materialized views are returned as + :class:`MaterializedSampleView` objects. 
+ + Args: + source_collection: the + :class:`fiftyone.core.collections.SampleCollection` from which this + view was created + materialize_stage: the :class:`fiftyone.core.stages.Materialize` stage + that created this view + materialized_dataset: the :class:`fiftyone.core.dataset.Dataset` that + serves the samples in this view + """ + + __slots__ = ( + "_source_collection", + "_materialize_stage", + "_materialized_dataset", + "__stages", + "__media_type", + "__name", + ) + + def __init__( + self, + source_collection, + materialize_stage, + materialized_dataset, + _stages=None, + _media_type=None, + _name=None, + ): + if _stages is None: + _stages = [] + + self._source_collection = source_collection + self._materialize_stage = materialize_stage + self._materialized_dataset = materialized_dataset + self.__stages = _stages + self.__media_type = _media_type + self.__name = _name + + def __copy__(self): + return self.__class__( + self._source_collection, + deepcopy(self._materialize_stage), + self._materialized_dataset, + _stages=deepcopy(self.__stages), + _media_type=self.__media_type, + _name=self.__name, + ) + + @property + def _base_view(self): + return self.__class__( + self._source_collection, + self._materialize_stage, + self._materialized_dataset, + ) + + @property + def _dataset(self): + return self._materialized_dataset + + @property + def _root_dataset(self): + return self._source_collection._root_dataset + + @property + def _sample_cls(self): + return MaterializedSampleView + + @property + def _stages(self): + return self.__stages + + @property + def _all_stages(self): + return ( + self._source_collection.view()._all_stages + + [self._materialize_stage] + + self.__stages + ) + + @property + def name(self): + return self.__name + + @property + def is_saved(self): + return self.__name is not None + + @property + def media_type(self): + if self.__media_type is not None: + return self.__media_type + + return self._dataset.media_type + + def _set_name(self, name): + self.__name = name + + def _set_media_type(self, media_type): + self.__media_type = media_type + + def _edit_sample_tags(self, update, ids=None): + ids = super()._edit_sample_tags(update, ids=ids) + + self._source_collection._edit_sample_tags(update, ids=ids) + + def _edit_label_tags( + self, update_fcn, label_field, ids=None, label_ids=None + ): + ids, label_ids = super()._edit_label_tags( + update_fcn, label_field, ids=ids, label_ids=label_ids + ) + + self._source_collection._edit_label_tags( + update_fcn, label_field, ids=ids, label_ids=label_ids + ) + + def set_values(self, field_name, *args, **kwargs): + # The `set_values()` operation could change the contents of this view, + # so we first record the sample IDs that need to be synced + if self._stages: + ids = self.values("id") + else: + ids = None + + super().set_values(field_name, *args, **kwargs) + + field = field_name.split(".", 1)[0] + self._sync_source(fields=[field], ids=ids) + self._sync_source_field_schema(field_name) + + def set_label_values(self, field_name, *args, **kwargs): + super().set_label_values(field_name, *args, **kwargs) + + self._source_collection.set_label_values(field_name, *args, **kwargs) + + def save(self, fields=None): + """Saves the samples in this view to the underlying dataset. + + .. note:: + + This method is not a :class:`fiftyone.core.stages.ViewStage`; + it immediately writes the requested changes to the underlying + dataset. + + .. 
warning:: + + This will permanently delete any omitted or filtered contents from + the samples of the source dataset. + + Args: + fields (None): an optional field or list of fields to save. If + specified, only these fields are overwritten + """ + if etau.is_str(fields): + fields = [fields] + + super().save(fields=fields) + + self._sync_source(fields=fields) + + def keep(self): + """Deletes all samples that are **not** in this view from the underlying + dataset. + + .. note:: + + This method is not a :class:`fiftyone.core.stages.ViewStage`; + it immediately writes the requested changes to the underlying + dataset. + """ + + # The `keep()` operation below will delete samples, so we must sync + # deletions to the source dataset first + self._sync_source(update=False, delete=True) + + super().keep() + + def keep_fields(self): + """Deletes any sample fields that have been excluded in this view from + the samples of the underlying dataset. + + .. note:: + + This method is not a :class:`fiftyone.core.stages.ViewStage`; + it immediately writes the requested changes to the underlying + dataset. + """ + self._sync_source_keep_fields() + + super().keep_fields() + + def keep_frames(self): + """For each sample in the view, deletes all frames that are **not** in + the view from the underlying dataset. + + .. note:: + + This method is not a :class:`fiftyone.core.stages.ViewStage`; + it immediately writes the requested changes to the underlying + dataset. + """ + self._sync_source_keep_frames() + + super().keep_frames() + + def reload(self): + """Reloads the view. + + Note that :class:`MaterializedSampleView` instances are not singletons, + so any in-memory samples extracted from this view will not be updated + by calling this method. + """ + self._source_collection.reload() + + # + # Regenerate the materialized dataset + # + # This assumes that calling `load_view()` when the current materialized + # dataset has been deleted will cause a new one to be generated + # + self._materialized_dataset.delete() + _view = self._materialize_stage.load_view(self._source_collection) + self._materialized_dataset = _view._materialized_dataset + + super().reload() + + def _set_labels(self, field_name, sample_ids, label_docs): + super()._set_labels(field_name, sample_ids, label_docs) + + self._sync_source(fields=[field_name], ids=sample_ids) + + def _delete_labels(self, ids, fields=None): + super()._delete_labels(ids, fields=fields) + + self._source_collection._delete_labels(ids, fields=fields) + + def _bulk_write( + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, + ): + self._materialized_dataset._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) + + self._sync_source_schema() + + self._source_collection._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) + + def _save_sample(self, sample, sample_ops=None, frame_ops=None): + if sample_ops: + foo.bulk_write( + sample_ops, self._materialized_dataset._sample_collection + ) + + if frame_ops: + foo.bulk_write( + frame_ops, self._materialized_dataset._frame_collection + ) + + self._sync_source_sample( + sample, sample_ops=sample_ops, frame_ops=frame_ops + ) + + def _sync_source_sample(self, sample, sample_ops=None, frame_ops=None): + self._sync_source_schema() + + if sample_ops is None and frame_ops is None: + dst_dataset = self._source_collection._root_dataset + + match = {"_id": sample._id} + 
updates = sample.to_mongo_dict() + dst_dataset._sample_collection.update_one(match, {"$set": updates}) + + if sample.media_type == fom.VIDEO: + src_coll = self._materialized_dataset._frame_collection + dst_coll_name = dst_dataset._frame_collection_name + pipeline = [ + {"$match": {"_sample_id": sample._id}}, + { + "$merge": { + "into": dst_coll_name, + "whenMatched": "replace", + } + }, + ] + foo.aggregate(src_coll, pipeline) + else: + if sample_ops: + self._source_collection._bulk_write( + sample_ops, ids=[sample.id] + ) + + if frame_ops: + self._source_collection._bulk_write( + frame_ops, sample_ids=[sample.id], frames=True + ) + + def _sync_source(self, fields=None, ids=None, update=True, delete=False): + has_frame_fields = self._has_frame_fields() + + if has_frame_fields and fields is not None: + sample_fields, frame_fields = fou.split_frame_fields(fields) + else: + sample_fields, frame_fields = fields, None + + dst_dataset = self._source_collection._root_dataset + + if update: + self._sync_source_schema(fields=fields) + + if fields is None or sample_fields: + pipeline = [] + + if ids is not None: + pipeline.append( + { + "$match": { + "_id": {"$in": [ObjectId(_id) for _id in ids]} + } + } + ) + + if sample_fields is not None: + project = {f: True for f in sample_fields} + project["_id"] = True + pipeline.append({"$project": project}) + + pipeline.append( + { + "$merge": { + "into": dst_dataset._sample_collection_name, + "on": "_id", + "whenMatched": "merge", + "whenNotMatched": "discard", + } + } + ) + + self._materialized_dataset._aggregate(pipeline=pipeline) + + if has_frame_fields and (fields is None or frame_fields): + pipeline = [] + post_pipeline = [] + + if ids is not None: + pipeline.append( + { + "$match": { + "_id": {"$in": [ObjectId(_id) for _id in ids]} + } + } + ) + + if frame_fields is not None: + project = {f: True for f in frame_fields} + project["_sample_id"] = True + project["frame_number"] = True + post_pipeline.append({"$project": project}) + + post_pipeline.append( + { + "$merge": { + "into": dst_dataset._frame_collection_name, + "on": ["_sample_id", "frame_number"], + "whenMatched": "merge", + "whenNotMatched": "discard", + } + } + ) + + self._materialized_dataset._aggregate( + pipeline=pipeline, + frames_only=True, + post_pipeline=post_pipeline, + ) + + if delete: + # It's okay to pass a materialized view to `dst_dataset` because + # they share sample IDs + dst_dataset._keep(view=self) + + def _sync_source_field_schema(self, path): + field = self.get_field(path) + if field is None: + return + + _path, is_frame_field = self._handle_frame_field(path) + + dst_dataset = self._source_collection._dataset + if is_frame_field: + dst_dataset._merge_frame_field_schema({_path: field}) + else: + dst_dataset._merge_sample_field_schema({path: field}) + + if self._source_collection._is_generated: + self._source_collection._sync_source_field_schema(path) + + def _sync_source_schema(self, fields=None, delete=False): + has_frame_fields = self._has_frame_fields() + + if has_frame_fields and fields is not None: + sample_fields, frame_fields = fou.split_frame_fields(fields) + else: + sample_fields, frame_fields = fields, None + + dst_dataset = self._source_collection._root_dataset + src_schema = self._source_collection.get_field_schema() + + if delete: + schema = self.get_field_schema() + else: + schema = self._materialized_dataset.get_field_schema() + + if has_frame_fields: + if delete: + frame_schema = self.get_frame_field_schema() + else: + frame_schema = ( + 
self._materialized_dataset.get_frame_field_schema() + ) + + src_frame_schema = self._source_collection.get_frame_field_schema() + + add_sample_fields = [] + del_sample_fields = [] + add_frame_fields = [] + del_frame_fields = [] + + if fields is not None: + # We're syncing specific fields; if they are not present in source + # collection, add them + + for field_name in sample_fields: + if field_name not in src_schema: + add_sample_fields.append(field_name) + else: + # We're syncing all fields; add any missing fields to source + # collection and, if requested, delete any source fields that + # aren't in this view + + for field_name in schema.keys(): + if field_name not in src_schema: + add_sample_fields.append(field_name) + + if delete: + for field_name in src_schema.keys(): + if field_name not in schema: + del_sample_fields.append(field_name) + + if has_frame_fields: + if fields is not None: + # We're syncing specific fields; if they are not present in + # source collection, add them + + for field_name in frame_fields: + if field_name not in src_frame_schema: + add_frame_fields.append(field_name) + else: + # We're syncing all fields; add any missing fields to source + # collection and, if requested, delete any source fields that + # aren't in this view + + for field_name in frame_schema.keys(): + if field_name not in src_frame_schema: + add_frame_fields.append(field_name) + + if delete: + for field_name in src_frame_schema.keys(): + if field_name not in frame_schema: + del_frame_fields.append(field_name) + + for field_name in add_sample_fields: + field_kwargs = foo.get_field_kwargs(schema[field_name]) + dst_dataset.add_sample_field(field_name, **field_kwargs) + + for field_name in add_frame_fields: + field_kwargs = foo.get_field_kwargs(frame_schema[field_name]) + dst_dataset.add_frame_field(field_name, **field_kwargs) + + if delete and del_sample_fields: + dst_dataset.delete_sample_fields(del_sample_fields) + + if delete and del_frame_fields: + dst_dataset.delete_frame_fields(del_frame_fields) + + def _sync_source_keep_fields(self): + schema = self.get_field_schema() + src_schema = self._source_collection.get_field_schema() + + if self._has_frame_fields(): + p = self._FRAMES_PREFIX + + frame_schema = self.get_frame_field_schema() + schema.update({p + k: v for k, v in frame_schema.items()}) + + src_frame_schema = self._source_collection.get_frame_field_schema() + src_schema.update({p + k: v for k, v in src_frame_schema.items()}) + + del_fields = set(src_schema.keys()) - set(schema.keys()) + if del_fields: + self._source_collection.exclude_fields(del_fields).keep_fields() + + def _sync_source_keep_frames(self): + # It's okay to pass a materialized view to `dst_dataset` because they + # share sample IDs and frame numbers + dst_dataset = self._source_collection._dataset + dst_dataset._keep_frames(view=self) + + +def materialize_view(sample_collection, name=None, persistent=False): + """Creates a dataset that contains a materialized copy of the given + collection. 
+ + Args: + sample_collection: a + :class:`fiftyone.core.collections.SampleCollection` + name (None): a name for the dataset + persistent (False): whether the dataset should persist in the database + after the session terminates + + Returns: + a :class:`fiftyone.core.dataset.Dataset` + """ + dataset = sample_collection._root_dataset + if isinstance(sample_collection, fov.DatasetView): + view = sample_collection + else: + # Materializing an entire dataset is a bit weird, but we'll allow it + view = sample_collection.view() + + return dataset._clone( + name=name, persistent=persistent, view=view, materialized=True + ) diff --git a/fiftyone/core/patches.py b/fiftyone/core/patches.py index 197bae101e8..9c23d2a93d6 100644 --- a/fiftyone/core/patches.py +++ b/fiftyone/core/patches.py @@ -9,6 +9,7 @@ from copy import deepcopy from bson import ObjectId +from pymongo import UpdateOne, UpdateMany import eta.core.utils as etau @@ -37,10 +38,13 @@ def _frame_id(self): return ObjectId(self._doc.frame_id) def _save(self, deferred=False): - sample_ops, frame_ops = super()._save(deferred=deferred) + sample_ops, frame_ops = super()._save(deferred=True) if not deferred: - self._view._sync_source_sample(self) + self._view._save_sample( + self, sample_ops=sample_ops, frame_ops=frame_ops + ) + return None, [] return sample_ops, frame_ops @@ -346,7 +350,59 @@ def reload(self): super().reload() - def _sync_source_sample(self, sample): + def _check_for_field_edits(self, ops, fields): + updated_fields = set() + + for op in ops: + if isinstance(op, (UpdateOne, UpdateMany)): + updated_fields.update(op._doc.get("$set", {}).keys()) + updated_fields.update(op._doc.get("$unset", {}).keys()) + + for field in list(updated_fields): + chunks = field.split(".") + for i in range(1, len(chunks)): + updated_fields.add(".".join(chunks[:i])) + + return bool(updated_fields & set(fields)) + + def _bulk_write( + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, + ): + self._patches_dataset._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) + + if self._check_for_field_edits(ops, self._label_fields): + self._sync_source(fields=self._label_fields, ids=ids) + self._source_collection._dataset._reload_docs(ids=ids) + + def _save_sample(self, sample, sample_ops=None, frame_ops=None): + if sample_ops: + foo.bulk_write( + sample_ops, self._patches_dataset._sample_collection + ) + + self._sync_source_sample( + sample, sample_ops=sample_ops, frame_ops=frame_ops + ) + + def _sync_source_sample(self, sample, sample_ops=None, frame_ops=None): + if sample_ops is not None and not self._check_for_field_edits( + sample_ops, self._label_fields + ): + return + for field in self._label_fields: self._sync_source_sample_field(sample, field) diff --git a/fiftyone/core/stages.py b/fiftyone/core/stages.py index a71272bf15b..6eea467774d 100644 --- a/fiftyone/core/stages.py +++ b/fiftyone/core/stages.py @@ -27,6 +27,7 @@ import fiftyone.core.frame as fofr import fiftyone.core.groups as fog import fiftyone.core.labels as fol +import fiftyone.core.materialize as foma import fiftyone.core.media as fom from fiftyone.core.odm.document import MongoEngineBaseDocument import fiftyone.core.sample as fos @@ -5530,6 +5531,74 @@ def _params(cls): ] +class Materialize(ViewStage): + """Materializes the current view into a temporary database collection. 
+ + Apply this stage to an expensive view (e.g., an unindexed filtering operation + on a large dataset) if you plan to perform multiple downstream operations + on the view. + + Examples:: + + import fiftyone as fo + import fiftyone.zoo as foz + from fiftyone import ViewField as F + + dataset = foz.load_zoo_dataset("quickstart") + + view = dataset.filter_labels("ground_truth", F("label") == "cat") + + stage = fo.Materialize() + materialized_view = view.add_stage(stage) + + print(view.count("ground_truth.detections")) + print(materialized_view.count("ground_truth.detections")) + """ + + def __init__(self, _state=None): + self._state = _state + + @property + def has_view(self): + return True + + def load_view(self, sample_collection): + state = { + "dataset": sample_collection.dataset_name, + "stages": sample_collection.view()._serialize(include_uuids=False), + } + + last_state = deepcopy(self._state) + if last_state is not None: + name = last_state.pop("name", None) + else: + name = None + + if state != last_state or not fod.dataset_exists(name): + materialized_dataset = foma.materialize_view(sample_collection) + + # Other views may use the same generated dataset, so reuse the old + # name if possible + if name is not None and state == last_state: + materialized_dataset.name = name + + state["name"] = materialized_dataset.name + self._state = state + else: + materialized_dataset = fod.load_dataset(name) + + return foma.MaterializedView( + sample_collection, self, materialized_dataset + ) + + def _kwargs(self): + return [["_state", self._state]] + + @classmethod + def _params(cls): + return [{"name": "_state", "type": "NoneType|json", "default": "None"}] + + class Mongo(ViewStage): """A view stage defined by a raw MongoDB aggregation pipeline. @@ -8627,6 +8696,7 @@ def repr_ViewExpression(self, expr, level): MatchFrames, MatchLabels, MatchTags, + Materialize, Mongo, Select, SelectBy, diff --git a/fiftyone/core/video.py b/fiftyone/core/video.py index 89a3f8b4606..2b26a025647 100644 --- a/fiftyone/core/video.py +++ b/fiftyone/core/video.py @@ -11,7 +11,7 @@ import os from bson import ObjectId -from pymongo import UpdateOne +from pymongo import UpdateOne, UpdateMany import eta.core.utils as etau @@ -54,10 +54,13 @@ def _sample_id(self): return ObjectId(self._doc.sample_id) def _save(self, deferred=False): - sample_ops, frame_ops = super()._save(deferred=deferred) + sample_ops, frame_ops = super()._save(deferred=True) if not deferred: - self._view._sync_source_sample(self) + self._view._save_sample( + self, sample_ops=sample_ops, frame_ops=frame_ops + ) + return None, [] return sample_ops, frame_ops @@ -327,29 +330,93 @@ def _delete_labels(self, ids, fields=None): self._source_collection._delete_labels(ids, fields=frame_fields) - def _sync_source_sample(self, sample): - self._sync_source_schema() - - dst_dataset = self._source_collection._root_dataset + def _prune_sample_only_field_updates(self, ops): sample_only_fields = self._get_sample_only_fields( include_private=True, use_db_fields=True ) - updates = { - k: v - for k, v in sample.to_mongo_dict().items() - if k not in sample_only_fields - } + for op in ops: + if isinstance(op, (UpdateOne, UpdateMany)): + sets = op._doc.get("$set", None) + if sets: + for f in sample_only_fields: + sets.pop(f, None) - if not updates: - return + unsets = op._doc.get("$unset", None) + if unsets: + for f in sample_only_fields: + unsets.pop(f, None) + + def _bulk_write( + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, + 
): + self._frames_dataset._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) - match = { - "_sample_id": sample._sample_id, - "frame_number": sample.frame_number, - } + self._sync_source_schema() + self._prune_sample_only_field_updates(ops) + + self._source_collection._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=True, + ordered=ordered, + progress=progress, + ) + + def _save_sample(self, sample, sample_ops=None, frame_ops=None): + if sample_ops: + foo.bulk_write(sample_ops, self._frames_dataset._sample_collection) + + self._sync_source_sample( + sample, sample_ops=sample_ops, frame_ops=frame_ops + ) - dst_dataset._frame_collection.update_one(match, {"$set": updates}) + def _sync_source_sample(self, sample, sample_ops=None, frame_ops=None): + self._sync_source_schema() + + if sample_ops is None: + dst_dataset = self._source_collection._root_dataset + sample_only_fields = self._get_sample_only_fields( + include_private=True, use_db_fields=True + ) + + updates = { + k: v + for k, v in sample.to_mongo_dict().items() + if k not in sample_only_fields + } + + if not updates: + return + + match = { + "_sample_id": sample._sample_id, + "frame_number": sample.frame_number, + } + + dst_dataset._frame_collection.update_one(match, {"$set": updates}) + else: + self._prune_sample_only_field_updates(sample_ops) + + self._source_collection._bulk_write( + sample_ops, + sample_ids=[sample.id], + frames=True, + ) def _sync_source(self, fields=None, ids=None, update=True, delete=False): dst_dataset = self._source_collection._root_dataset diff --git a/fiftyone/core/view.py b/fiftyone/core/view.py index a7ba59c9832..c969b86eec0 100644 --- a/fiftyone/core/view.py +++ b/fiftyone/core/view.py @@ -199,6 +199,10 @@ def _is_frames(self): def _is_clips(self): return self._dataset._is_clips + @property + def _is_materialized(self): + return self._dataset._is_materialized + @property def _is_dynamic_groups(self): return self._outputs_dynamic_groups() @@ -1320,8 +1324,8 @@ def keep_fields(self): self._dataset._keep_fields(view=self) def keep_frames(self): - """For each sample in the view, deletes all frames labels that are - **not** in the view from the underlying dataset. + """For each sample in the view, deletes all frames that are **not** in + the view from the underlying dataset. .. note:: diff --git a/tests/unittests/materialize_tests.py b/tests/unittests/materialize_tests.py new file mode 100644 index 00000000000..f3b28903bb4 --- /dev/null +++ b/tests/unittests/materialize_tests.py @@ -0,0 +1,378 @@ +""" +FiftyOne materialized view-related unit tests. + +| Copyright 2017-2024, Voxel51, Inc. 
+| `voxel51.com `_ +| +""" +from copy import deepcopy + +from bson import ObjectId +import unittest + +import fiftyone as fo +from fiftyone import ViewField as F + +from decorators import drop_datasets + + +class MaterializeTests(unittest.TestCase): + @drop_datasets + def test_materialize(self): + dataset = fo.Dataset() + + sample1 = fo.Sample( + filepath="video1.mp4", + tags=["test"], + weather="sunny", + ) + sample1.frames[1] = fo.Frame() + sample1.frames[2] = fo.Frame( + ground_truth=fo.Detections( + detections=[ + fo.Detection(label="cat"), + fo.Detection(label="dog"), + ] + ), + ) + sample1.frames[3] = fo.Frame() + + sample2 = fo.Sample( + filepath="video2.mp4", + tags=["test"], + weather="cloudy", + ) + sample2.frames[1] = fo.Frame( + ground_truth=fo.Detections( + detections=[ + fo.Detection(label="dog"), + fo.Detection(label="rabbit"), + ] + ), + ) + sample2.frames[3] = fo.Frame() + sample2.frames[5] = fo.Frame() + + sample3 = fo.Sample( + filepath="video3.mp4", + tags=["test"], + weather="rainy", + ) + + dataset.add_samples([sample1, sample2, sample3]) + + view = ( + dataset.limit(2) + .match_frames(F("frame_number") <= 2, omit_empty=False) + .materialize() + ) + + self.assertSetEqual( + set(view.get_field_schema().keys()), + { + "id", + "filepath", + "metadata", + "tags", + "created_at", + "last_modified_at", + "weather", + }, + ) + + self.assertSetEqual( + set(view.get_frame_field_schema().keys()), + { + "id", + "frame_number", + "created_at", + "last_modified_at", + "ground_truth", + }, + ) + + self.assertEqual( + view.get_field("metadata").document_type, + fo.VideoMetadata, + ) + + self.assertSetEqual( + set(view.select_fields().get_field_schema().keys()), + { + "id", + "filepath", + "metadata", + "tags", + "created_at", + "last_modified_at", + }, + ) + + self.assertSetEqual( + set(view.select_fields().get_frame_field_schema().keys()), + { + "id", + "frame_number", + "created_at", + "last_modified_at", + }, + ) + + with self.assertRaises(ValueError): + view.exclude_fields("tags") # can't exclude default field + + with self.assertRaises(ValueError): + view.exclude_fields( + "frames.frame_number" + ) # can't exclude default field + + index_info = view.get_index_information() + indexes = view.list_indexes() + default_indexes = { + "id", + "filepath", + "created_at", + "last_modified_at", + "frames.id", + "frames._sample_id_1_frame_number_1", + "frames.created_at", + "frames.last_modified_at", + } + + self.assertSetEqual(set(index_info.keys()), default_indexes) + self.assertSetEqual(set(indexes), default_indexes) + + with self.assertRaises(ValueError): + view.drop_index("id") # can't drop default index + + with self.assertRaises(ValueError): + view.drop_index("filepath") # can't drop default index + + with self.assertRaises(ValueError): + view.drop_index("frames.created_at") # can't drop default index + + self.assertEqual(len(view), 2) + self.assertEqual(view.count("frames"), 3) + + sample = view.first() + self.assertIsInstance(sample.id, str) + self.assertIsInstance(sample._id, ObjectId) + + for _id in view.values("id"): + self.assertIsInstance(_id, str) + + for oid in view.values("_id"): + self.assertIsInstance(oid, ObjectId) + + for _id in view.values("frames.id", unwind=True): + self.assertIsInstance(_id, str) + + for oid in view.values("frames._id", unwind=True): + self.assertIsInstance(oid, ObjectId) + + self.assertDictEqual(dataset.count_sample_tags(), {"test": 3}) + self.assertDictEqual(view.count_sample_tags(), {"test": 2}) + + view.tag_samples("foo") + + 
self.assertEqual(view.count_sample_tags()["foo"], 2) + self.assertEqual(dataset.count_sample_tags()["foo"], 2) + + view.untag_samples("foo") + + self.assertNotIn("foo", view.count_sample_tags()) + self.assertNotIn("foo", dataset.count_sample_tags()) + + view.tag_labels("test") + + self.assertDictEqual(view.count_label_tags(), {"test": 4}) + self.assertDictEqual(dataset.count_label_tags(), {"test": 4}) + + view.select_labels(tags="test").untag_labels("test") + + self.assertDictEqual(view.count_label_tags(), {}) + self.assertDictEqual(dataset.count_label_tags(), {}) + + view2 = view.limit(1).set_field( + "frames.ground_truth.detections.label", F("label").upper() + ) + + self.assertDictEqual( + view.count_values("frames.ground_truth.detections.label"), + {"cat": 1, "dog": 2, "rabbit": 1}, + ) + self.assertDictEqual( + view2.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1}, + ) + self.assertDictEqual( + dataset.count_values("frames.ground_truth.detections.label"), + {"cat": 1, "dog": 2, "rabbit": 1}, + ) + + values = { + _id: v + for _id, v in zip( + *view2.values( + [ + "frames.ground_truth.detections.id", + "frames.ground_truth.detections.label", + ], + unwind=True, + ) + ) + } + view.set_label_values( + "frames.ground_truth.detections.also_label", values + ) + + self.assertEqual( + view.count("frames.ground_truth.detections.also_label"), 2 + ) + self.assertEqual( + dataset.count("frames.ground_truth.detections.also_label"), 2 + ) + self.assertDictEqual( + view.count_values("frames.ground_truth.detections.also_label"), + dataset.count_values("frames.ground_truth.detections.also_label"), + ) + + view2.save() + + self.assertEqual(len(view), 2) + self.assertEqual(dataset.values(F("frames").length()), [3, 3, 0]) + self.assertDictEqual( + view.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1, "dog": 1, "rabbit": 1}, + ) + self.assertDictEqual( + dataset.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1, "dog": 1, "rabbit": 1}, + ) + + view2.keep() + view2.keep_frames() + view.reload() + + self.assertEqual(len(view), 1) + self.assertEqual(dataset.values(F("frames").length()), [2]) + self.assertDictEqual( + view.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1}, + ) + self.assertDictEqual( + dataset.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1}, + ) + + sample = view.exclude_fields("weather").first() + + sample["foo"] = "bar" + sample.save() + + self.assertIn("foo", view.get_field_schema()) + self.assertIn("foo", dataset.get_field_schema()) + self.assertIn("weather", view.get_field_schema()) + self.assertIn("weather", dataset.get_field_schema()) + self.assertEqual(view.count_values("foo")["bar"], 1) + self.assertEqual(dataset.count_values("foo")["bar"], 1) + self.assertDictEqual(view.count_values("weather"), {"sunny": 1}) + self.assertDictEqual(dataset.count_values("weather"), {"sunny": 1}) + + sample = view.exclude_fields("frames.ground_truth").first() + frame = sample.frames.first() + + frame["spam"] = "eggs" + sample.save() + + self.assertIn("spam", view.get_frame_field_schema()) + self.assertIn("spam", dataset.get_frame_field_schema()) + self.assertIn("ground_truth", view.get_frame_field_schema()) + self.assertIn("ground_truth", dataset.get_frame_field_schema()) + self.assertEqual(view.count_values("frames.spam")["eggs"], 1) + self.assertEqual(dataset.count_values("frames.spam")["eggs"], 1) + self.assertDictEqual( + 
view.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1}, + ) + self.assertDictEqual( + dataset.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1}, + ) + + dataset.untag_samples("test") + view.reload() + + self.assertEqual(dataset.count_sample_tags(), {}) + self.assertEqual(view.count_sample_tags(), {}) + + view.select_fields().keep_fields() + + self.assertNotIn("weather", view.get_field_schema()) + self.assertNotIn("weather", dataset.get_field_schema()) + self.assertNotIn("ground_truth", view.get_frame_field_schema()) + self.assertNotIn("ground_truth", dataset.get_frame_field_schema()) + + sample_view = view.first() + with self.assertRaises(KeyError): + sample_view["weather"] + + frame_view = sample_view.frames.first() + with self.assertRaises(KeyError): + frame_view["ground_truth"] + + # Test saving a materialized view + + self.assertIsNone(view.name) + + view_name = "test" + dataset.save_view(view_name, view) + self.assertEqual(view.name, view_name) + self.assertTrue(view.is_saved) + + also_view = dataset.load_saved_view(view_name) + self.assertEqual(view, also_view) + self.assertEqual(also_view.name, view_name) + self.assertTrue(also_view.is_saved) + + still_view = deepcopy(view) + self.assertEqual(still_view.name, view_name) + self.assertTrue(still_view.is_saved) + self.assertEqual(still_view, view) + + @drop_datasets + def test_materialize_save_context(self): + dataset = fo.Dataset() + + sample1 = fo.Sample(filepath="video1.mp4") + sample1.frames[1] = fo.Frame(filepath="frame11.jpg") + sample1.frames[2] = fo.Frame(filepath="frame12.jpg") + sample1.frames[3] = fo.Frame(filepath="frame13.jpg") + + sample2 = fo.Sample(filepath="video2.mp4") + + sample3 = fo.Sample(filepath="video3.mp4") + sample3.frames[1] = fo.Frame(filepath="frame31.jpg") + + dataset.add_samples([sample1, sample2, sample3]) + + view = ( + dataset.limit(2) + .match_frames(F("frame_number") != 2, omit_empty=False) + .materialize() + ) + + for sample in view.iter_samples(autosave=True): + sample["foo"] = "bar" + for frame in sample.frames.values(): + frame["foo"] = "bar" + + self.assertEqual(view.count("foo"), 2) + self.assertEqual(dataset.count("foo"), 2) + self.assertEqual(view.count("frames.foo"), 2) + self.assertEqual(dataset.count("frames.foo"), 2) + + +if __name__ == "__main__": + fo.config.show_progress_bars = False + unittest.main(verbosity=2)
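A final note on the caching behavior of ``Materialize.load_view()`` above: the stage serializes the source dataset name and view stages into its ``_state``, and if that state is unchanged and the cached dataset still exists, reloading the view (e.g., via a saved view) reuses the cached collection rather than re-materializing it. A sketch of the expected behavior (``_materialize_stage`` and ``_state`` are private internals, shown for illustration only):

.. code-block:: python

    import fiftyone as fo
    import fiftyone.zoo as foz
    from fiftyone import ViewField as F

    dataset = foz.load_zoo_dataset("quickstart")
    view = dataset.filter_labels("ground_truth", F("label") == "cat")

    materialized_view = view.materialize()
    stage = materialized_view._materialize_stage

    # The stage remembers the name of the backing "materialized.*" dataset
    cached_name = stage._state["name"]

    # Re-loading with an unchanged source view reuses the cached dataset
    same_view = stage.load_view(view)
    assert same_view._dataset.name == cached_name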