Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QP sidebar filters to active slice for group datasets #5177

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions app/packages/state/src/recoil/queryPerformance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { graphQLSelectorFamily } from "recoil-relay";
import type { ResponseFrom } from "../utils";
import { config } from "./config";
import { getBrowserStorageEffectForKey } from "./customEffects";
import { groupSlice } from "./groups";
import { isLabelPath } from "./labels";
import { RelayEnvironmentKey } from "./relay";
import * as schemaAtoms from "./schema";
Expand All @@ -34,6 +35,7 @@ export const lightningQuery = graphQLSelectorFamily<
input: {
dataset: get(datasetName),
paths,
slice: get(groupSlice),
},
};
},
Expand Down Expand Up @@ -83,6 +85,8 @@ const indexesByPath = selector({

const { sampleIndexes: samples, frameIndexes: frames } = get(indexes);

console.log(samples);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove debugging console.log statement

Debug logging should not be committed to production code.

Apply this diff to remove the debugging statement:

-    console.log(samples);
-
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
console.log(samples);

const schema = gatherPaths(State.SPACE.SAMPLE);
const frameSchema = gatherPaths(State.SPACE.FRAME).map((p) =>
p.slice("frames.".length)
Expand Down
1 change: 1 addition & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ input LabelTagColorInput {
"""
Input for the lightning (query performance) resolver: names the dataset,
the field paths to query, and optionally a group slice to restrict results to.
"""
input LightningInput {
  "Name of the dataset to query"
  dataset: String!
  "Field paths to resolve lightning results for"
  paths: [LightningPathInput!]!
  "Active group slice name; null disables slice filtering"
  slice: String = null
}

input LightningPathInput {
Expand Down
6 changes: 0 additions & 6 deletions docs/source/user_guide/app.rst
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,6 @@ perform initial filters on:
# Note: it is faster to declare indexes before adding samples
dataset.add_samples(...)

fo.app_config.default_query_performance = True

session = fo.launch_app(dataset)

.. note::
Expand Down Expand Up @@ -521,8 +519,6 @@ compound index that includes the group slice name:
dataset.create_index("ground_truth.detections.label")
dataset.create_index([("group.name", 1), ("ground_truth.detections.label", 1)])

fo.app_config.default_query_performance = True

session = fo.launch_app(dataset)

For datasets with a small number of fields, you can index all fields by adding
Expand All @@ -538,8 +534,6 @@ a single
dataset = foz.load_zoo_dataset("quickstart")
dataset.create_index("$**")

fo.app_config.default_query_performance = True

session = fo.launch_app(dataset)

.. warning::
Expand Down
33 changes: 25 additions & 8 deletions fiftyone/server/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from bson import ObjectId
from dataclasses import asdict, dataclass
from datetime import date, datetime
import math
import typing as t

import asyncio
Expand Down Expand Up @@ -46,6 +45,7 @@ class LightningPathInput:
class LightningInput:
dataset: str
paths: t.List[LightningPathInput]
slice: t.Optional[str] = None


@gql.interface
Expand Down Expand Up @@ -138,7 +138,13 @@ async def lightning_resolver(
for collection, sublist in zip(collections, queries)
for item in sublist
]
result = await _do_async_pooled_queries(dataset, flattened)

filter = (
{f"{dataset.group_field}.name": input.slice}
if dataset.group_field and input.slice
else None
)
result = await _do_async_pooled_queries(dataset, flattened, filter)

results = []
offset = 0
Expand Down Expand Up @@ -293,10 +299,11 @@ async def _do_async_pooled_queries(
queries: t.List[
t.Tuple[AsyncIOMotorCollection, t.Union[DistinctQuery, t.List[t.Dict]]]
],
filter: t.Optional[t.Mapping[str, str]],
):
return await asyncio.gather(
*[
_do_async_query(dataset, collection, query)
_do_async_query(dataset, collection, query, filter)
for collection, query in queries
]
)
Expand All @@ -306,25 +313,31 @@ async def _do_async_query(
dataset: fo.Dataset,
collection: AsyncIOMotorCollection,
query: t.Union[DistinctQuery, t.List[t.Dict]],
filter: t.Optional[t.Mapping[str, str]],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add type checking before modifying the query list

The current implementation assumes query is always a list when applying the filter. Add type checking to ensure safe operation.

    if filter:
+       if not isinstance(query, list):
+           raise TypeError("Expected query to be a list for filter application")
        query.insert(0, {"$match": filter})

Also applies to: 324-325

):
if isinstance(query, DistinctQuery):
if query.has_list and not query.filters:
return await _do_distinct_query(collection, query)
return await _do_distinct_query(collection, query, filter)

return await _do_distinct_pipeline(dataset, collection, query, filter)

return await _do_distinct_pipeline(dataset, collection, query)
if filter:
query.insert(0, {"$match": filter})

return [i async for i in collection.aggregate(query)]


async def _do_distinct_query(
collection: AsyncIOMotorCollection, query: DistinctQuery
collection: AsyncIOMotorCollection,
query: DistinctQuery,
filter: t.Optional[t.Mapping[str, str]],
):
match = None
if query.search:
match = query.search

try:
result = await collection.distinct(query.path)
result = await collection.distinct(query.path, filter)
except:
# too many results
return None
Expand All @@ -350,12 +363,16 @@ async def _do_distinct_pipeline(
dataset: fo.Dataset,
collection: AsyncIOMotorCollection,
query: DistinctQuery,
filter: t.Optional[t.Mapping[str, str]],
):
pipeline = []
if filter:
pipeline.append({"$match": filter})

if query.filters:
pipeline += get_view(dataset, filters=query.filters)._pipeline()

pipeline += [{"$sort": {query.path: 1}}]
pipeline.append({"$sort": {query.path: 1}})

if query.search:
if query.is_object_id_field:
Expand Down
119 changes: 108 additions & 11 deletions tests/unittests/lightning_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,91 @@ async def test_strings(self, dataset: fo.Dataset):
)


class TestGroupDatasetLightningQueries(unittest.IsolatedAsyncioTestCase):
    """Verifies that lightning queries respect the active group slice.

    Builds a grouped dataset with two slices ("one" and "two"), each
    containing one sample with distinct label/numeric/string values, then
    asserts that passing ``slice`` in ``LightningInput`` restricts the
    resolved values to samples from that slice only.
    """

    @drop_async_dataset
    async def test_group_dataset(self, dataset: fo.Dataset):
        # One group with two slices; per-slice values are distinct so any
        # cross-slice leakage would be visible in the assertions below
        group = fo.Group()
        one = fo.Sample(
            classifications=fo.Classifications(
                classifications=[fo.Classification(label="one")]
            ),
            filepath="one.png",
            group=group.element("one"),
            numeric=1,
            string="one",
        )
        two = fo.Sample(
            classifications=fo.Classifications(
                classifications=[fo.Classification(label="two")]
            ),
            filepath="two.png",
            group=group.element("two"),
            numeric=2,
            string="two",
        )
        dataset.add_samples([one, two])

        # Inline fragments select the result fields relevant to each
        # lightning result type (int min/max, string distinct values)
        query = """
            query Query($input: LightningInput!) {
                lightning(input: $input) {
                    ... on IntLightningResult {
                        path
                        min
                        max
                    }
                    ... on StringLightningResult {
                        path
                        values
                    }
                }
            }
        """

        # only query "one" slice samples
        result = await _execute(
            query,
            dataset,
            (fo.IntField, fo.StringField),
            ["classifications.classifications.label", "numeric", "string"],
            frames=False,
            slice="one",
        )

        # Only values from the "one" slice sample should be present
        self.assertListEqual(
            result.data["lightning"],
            [
                {
                    "path": "classifications.classifications.label",
                    "values": ["one"],
                },
                {"path": "numeric", "min": 1.0, "max": 1.0},
                {"path": "string", "values": ["one"]},
            ],
        )

        # only query "two" slice samples
        result = await _execute(
            query,
            dataset,
            (fo.IntField, fo.StringField),
            ["classifications.classifications.label", "numeric", "string"],
            frames=False,
            slice="two",
        )

        # Only values from the "two" slice sample should be present
        self.assertListEqual(
            result.data["lightning"],
            [
                {
                    "path": "classifications.classifications.label",
                    "values": ["two"],
                },
                {"path": "numeric", "min": 2.0, "max": 2.0},
                {"path": "string", "values": ["two"]},
            ],
        )


def _add_samples(dataset: fo.Dataset, *sample_data: t.List[t.Dict]):
samples = []
keys = set()
Expand All @@ -1067,7 +1152,12 @@ def _add_samples(dataset: fo.Dataset, *sample_data: t.List[t.Dict]):


async def _execute(
query: str, dataset: fo.Dataset, field: fo.Field, keys: t.Set[str]
query: str,
dataset: fo.Dataset,
field: fo.Field,
keys: t.Set[str],
frames=True,
slice: t.Optional[str] = None,
):
return await execute(
schema,
Expand All @@ -1076,25 +1166,32 @@ async def _execute(
"input": asdict(
LightningInput(
dataset=dataset.name,
paths=_get_paths(dataset, field, keys),
paths=_get_paths(dataset, field, keys, frames=frames),
slice=slice,
)
)
},
)


def _get_paths(
dataset: fo.Dataset, field_type: t.Type[fo.Field], keys: t.Set[str]
dataset: fo.Dataset,
field_type: t.Type[fo.Field],
keys: t.Set[str],
frames=True,
):
field_dict = dataset.get_field_schema(flat=True)
field_dict.update(
**{
f"frames.{path}": field
for path, field in dataset.get_frame_field_schema(
flat=True
).items()
}
)

if frames:
field_dict.update(
**{
f"frames.{path}": field
for path, field in dataset.get_frame_field_schema(
flat=True
).items()
}
)

paths: t.List[LightningPathInput] = []
for path in sorted(field_dict):
field = field_dict[path]
Expand Down
Loading