diff --git a/app/packages/core/src/components/Actions/Tagger.tsx b/app/packages/core/src/components/Actions/Tagger.tsx index 2850206d9aa..667bacc4e0f 100644 --- a/app/packages/core/src/components/Actions/Tagger.tsx +++ b/app/packages/core/src/components/Actions/Tagger.tsx @@ -348,7 +348,11 @@ const useTagCallback = ( fos.isOrderedDynamicGroup ); - const slices = await snapshot.getPromise(fos.currentSlices(modal)); + const mode = await snapshot.getPromise(groupStatistics(modal)); + const currentSlices = await snapshot.getPromise( + fos.currentSlices(modal) + ); + const slices = await snapshot.getPromise(fos.groupSlices); const { samples } = await getFetchFunction()("POST", "/tag", { ...tagParameters({ activeFields: await snapshot.getPromise( @@ -363,9 +367,10 @@ const useTagCallback = ( isGroup && !isNonNestedDynamicGroup ? { id: modal ? await snapshot.getPromise(groupId) : null, - slices, + mode, + currentSlices, slice: await snapshot.getPromise(fos.groupSlice), - mode: await snapshot.getPromise(groupStatistics(modal)), + slices, } : null, modal, diff --git a/app/packages/core/src/components/Actions/utils.tsx b/app/packages/core/src/components/Actions/utils.tsx index 6a29ec2d4ad..82d2284b6de 100644 --- a/app/packages/core/src/components/Actions/utils.tsx +++ b/app/packages/core/src/components/Actions/utils.tsx @@ -84,9 +84,10 @@ export const tagStatistics = selectorFamily< get(isGroup) && get(fos.groupField) ? { id: modal ? get(groupId) : null, - slices: get(fos.currentSlices(modal)), - slice: get(fos.currentSlice(modal)), + currentSlices: get(fos.currentSlices(modal)), mode: get(groupStatistics(modal)), + slice: get(fos.currentSlice(modal)), + slices: get(fos.groupSlices), } : null, hiddenLabels: get(fos.hiddenLabelsArray), @@ -131,18 +132,7 @@ export const tagStats = selectorFamily< get: ({ modal, labels }) => ({ get }) => { - const data = Object.keys( - get( - labels - ? fos.labelTagCounts({ modal: false, extended: false }) - : fos.sampleTagCounts({ modal: false, extended: false }) - ) - ).map((t) => [t, 0]); - - return { - ...Object.fromEntries(data), - ...get(tagStatistics({ modal, labels })).tags, - }; + return get(tagStatistics({ modal, labels })).tags; }, }); @@ -166,6 +156,7 @@ export const tagParameters = ({ activeFields: string[]; groupData: { id: string | null; + currentSlices: string[] | null; slice: string | null; slices: string[] | null; mode: "group" | "slice"; @@ -174,8 +165,11 @@ export const tagParameters = ({ sampleId: string | null; }) => { const shouldShowCurrentSample = - params.modal && selectedSamples.size == 0 && hiddenLabels.length == 0; + params.modal && selectedSamples.size === 0 && hiddenLabels.length === 0; const groups = groupData?.mode === "group"; + if (groupData && !groups) { + groupData.slices = groupData.currentSlices; + } const getSampleIds = () => { if (shouldShowCurrentSample && !groups) { @@ -186,9 +180,11 @@ export const tagParameters = ({ return [...new Set(selectedLabels.map((l) => l.sampleId))]; } return [sampleId]; - } else if (selectedSamples.size) { + } + if (selectedSamples.size) { return [...selectedSamples]; } + return null; }; @@ -196,7 +192,7 @@ export const tagParameters = ({ ...params, label_fields: activeFields, target_labels: targetLabels, - slices: !groups ? groupData?.slices : null, + slices: groupData?.slices, slice: groupData?.slice, group_id: params.modal ? groupData?.id : null, sample_ids: getSampleIds(), diff --git a/app/packages/state/src/recoil/aggregations.ts b/app/packages/state/src/recoil/aggregations.ts index 77de88fdb68..86befde765a 100644 --- a/app/packages/state/src/recoil/aggregations.ts +++ b/app/packages/state/src/recoil/aggregations.ts @@ -5,7 +5,13 @@ import { graphQLSelectorFamily } from "recoil-relay"; import type { ResponseFrom } from "../utils"; import { refresher } from "./atoms"; import * as filterAtoms from "./filters"; -import { currentSlices, groupId, groupSlice, groupStatistics } from "./groups"; +import { + currentSlices, + groupId, + groupSlice, + groupSlices, + groupStatistics, +} from "./groups"; import { sidebarSampleId } from "./modal"; import { RelayEnvironmentKey } from "./relay"; import * as schemaAtoms from "./schema"; @@ -76,7 +82,7 @@ export const aggregationQuery = graphQLSelectorFamily< paths, mixed, sampleIds, - slices: mixed ? null : get(currentSlices(modal)), // when mixed, slice is not needed + slices: mixed ? get(groupSlices) : get(currentSlices(modal)), slice: get(groupSlice), view: customView ? customView : !root ? get(viewAtoms.view) : [], }; @@ -150,15 +156,22 @@ export const modalAggregationPaths = selectorFamily({ const isFramesPath = frames.some((p) => params.path.startsWith(p)); let paths = isFramesPath ? frames - : get(schemaAtoms.labelFields({ space: State.SPACE.SAMPLE })).map( - (path) => get(schemaAtoms.expandPath(path)) - ); + : [ + ...get(schemaAtoms.labelFields({ space: State.SPACE.SAMPLE })).map( + (path) => get(schemaAtoms.expandPath(path)) + ), + ]; paths = paths .sort() .flatMap((p) => get(schemaAtoms.modalFilterFields(p))); const numeric = get(schemaAtoms.isNumericField(params.path)); + if (!isFramesPath && !numeric) { + // the modal currently requires a 'tags' aggregation + paths = ["tags", ...paths]; + } + if (params.mixed || get(groupId)) { paths = [ ...paths.filter((p) => { @@ -166,11 +179,6 @@ export const modalAggregationPaths = selectorFamily({ return numeric ? n : !n; }), ]; - - if (!numeric && !isFramesPath) { - // the modal currently requires a 'tags' aggregation - paths = ["tags", ...paths]; - } } return paths; diff --git a/app/packages/state/src/recoil/modal.ts b/app/packages/state/src/recoil/modal.ts index 6fca56f44b8..7635c3f0070 100644 --- a/app/packages/state/src/recoil/modal.ts +++ b/app/packages/state/src/recoil/modal.ts @@ -29,7 +29,7 @@ export const modalLooker = atom({ dangerouslyAllowMutability: true, }); -export const sidebarSampleId = selector({ +export const sidebarSampleId = selector({ key: "sidebarSampleId", get: ({ get }) => { if (get(shouldRenderImaVidLooker(true))) { @@ -41,7 +41,7 @@ export const sidebarSampleId = selector({ if (!isPlaying && !isSeeking && thisFrameNumber && sample) { // is the type incorrect? fix me - const id = sample?.id || sample?._id || sample?.sample?._id; + const id = sample?.id || sample?._id || (sample?.sample?._id as string); if (id) { return id; } diff --git a/app/packages/state/src/recoil/pathData/tags.ts b/app/packages/state/src/recoil/pathData/tags.ts index 5eb6005e586..b955395a6c2 100644 --- a/app/packages/state/src/recoil/pathData/tags.ts +++ b/app/packages/state/src/recoil/pathData/tags.ts @@ -1,5 +1,6 @@ import { selectorFamily } from "recoil"; import { aggregation } from "../aggregations"; +import { groupStatistics } from "../groups"; import * as schemaAtoms from "../schema"; export const labelTagCounts = selectorFamily< @@ -11,7 +12,14 @@ export const labelTagCounts = selectorFamily< ({ modal, extended }) => ({ get }) => { const data = get(schemaAtoms.labelPaths({})).map((path) => - get(aggregation({ extended, modal, path: `${path}.tags`, mixed: true })) + get( + aggregation({ + extended, + modal, + path: `${path}.tags`, + mixed: get(groupStatistics(modal)) === "group", + }) + ) ); const result = {}; @@ -45,7 +53,13 @@ export const sampleTagCounts = selectorFamily< get: (params) => ({ get }) => { - const data = get(aggregation({ ...params, path: "tags" })); + const data = get( + aggregation({ + ...params, + path: "tags", + mixed: get(groupStatistics(params.modal)) === "group", + }) + ); if (data.__typename !== "StringAggregation") { throw new Error("unexpected"); } diff --git a/fiftyone/server/aggregations.py b/fiftyone/server/aggregations.py index 7d41893f776..56dcd0c9645 100644 --- a/fiftyone/server/aggregations.py +++ b/fiftyone/server/aggregations.py @@ -127,10 +127,6 @@ async def aggregate_resolver( if form.sample_ids: view = fov.make_optimized_select_view(view, form.sample_ids) - if form.mixed and view.media_type == fom.GROUP and view.group_slices: - view = view.select_group_slices(_force_mixed=True) - view = fosv.get_extended_view(view, form.filters) - if form.hidden_labels: view = view.exclude_labels( [