Skip to content

Commit

Permalink
[ML] update stats on refresh
Browse the repository at this point in the history
  • Loading branch information
darnautov committed Aug 19, 2020
1 parent e8a46b3 commit 5985295
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,11 @@ export const useRefreshAnalyticsList = (
subscriptions.push(
distinct$
.pipe(filter((state) => state === REFRESH_ANALYTICS_LIST_STATE.REFRESH))
.subscribe(() => typeof callback.onRefresh === 'function' && callback.onRefresh())
.subscribe(() => {
if (typeof callback.onRefresh === 'function') {
callback.onRefresh();
}
})
);
}

Expand All @@ -353,7 +357,7 @@ export const useRefreshAnalyticsList = (
return () => {
subscriptions.map((sub) => sub.unsubscribe());
};
}, []);
}, [callback.onRefresh]);

return {
refresh: () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
EuiText,
EuiHorizontalRule,
EuiFlexGroup,
EuiTextColor,
} from '@elastic/eui';
// @ts-ignore
import { formatDate } from '@elastic/eui/lib/services/format';
Expand Down Expand Up @@ -233,9 +234,11 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
<EuiFlexGroup>
<EuiFlexItem grow={false}>
<EuiTitle size={'xs'}>
<h5>
{i + 1}. {pipelineName}
</h5>
<EuiTextColor color="subdued">
<h5>
{i + 1}. {pipelineName}
</h5>
</EuiTextColor>
</EuiTitle>
</EuiFlexItem>
<EuiFlexItem>
Expand Down Expand Up @@ -267,7 +270,9 @@ export const ExpandedRow: FC<ExpandedRowProps> = ({ item }) => {
<EuiFlexGroup>
<EuiFlexItem grow={false}>
<EuiTitle size={'xxs'}>
<h6>{name}</h6>
<EuiTextColor color="subdued">
<h6>{name}</h6>
</EuiTextColor>
</EuiTitle>
</EuiFlexItem>
<EuiFlexItem>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,41 +78,65 @@ export const ModelsList: FC = () => {

const [modelsToDelete, setModelsToDelete] = useState<ModelItemFull[]>([]);

const [itemIdToExpandedRowMap, setItemIdToExpandedRowMap] = useState<Record<string, any>>({});

// Subscribe to the refresh observable to trigger reloading the model list.
useRefreshAnalyticsList({
isLoading: setIsLoading,
onRefresh: fetchData,
});
const [itemIdToExpandedRowMap, setItemIdToExpandedRowMap] = useState<Record<string, JSX.Element>>(
{}
);

/**
* Fetches inference trained models.
*/
async function fetchData() {
const fetchData = useCallback(async () => {
try {
const response = await inferenceApiService.getInferenceModel(undefined, {
with_pipelines: true,
size: 1000,
});
setItems(
response.map((v) => ({
...v,
...(typeof v.inference_config === 'object'
? { type: Object.keys(v.inference_config)[0] }

const newItems = [];
const expandedItemsToRefresh = [];

for (const model of response) {
const tableItem = {
...model,
...(typeof model.inference_config === 'object'
? { type: Object.keys(model.inference_config)[0] }
: {}),
}))
);
};
newItems.push(tableItem);

if (itemIdToExpandedRowMap[model.model_id]) {
expandedItemsToRefresh.push(tableItem);
}
}

setItems(newItems);

if (expandedItemsToRefresh.length > 0) {
await fetchModelsStats(expandedItemsToRefresh);

setItemIdToExpandedRowMap(
expandedItemsToRefresh.reduce((acc, item) => {
acc[item.model_id] = <ExpandedRow item={item as ModelItemFull} />;
return acc;
}, {} as Record<string, JSX.Element>)
);
}
} catch (error) {
toasts.addError(new Error(error.body.message), {
toasts.addError(new Error(error.body?.message), {
title: i18n.translate('xpack.ml.inference.modelsList.fetchFailedErrorMessage', {
defaultMessage: 'Models fetch failed',
}),
});
}
setIsLoading(false);
refreshAnalyticsList$.next(REFRESH_ANALYTICS_LIST_STATE.IDLE);
}
}, [itemIdToExpandedRowMap]);

// Subscribe to the refresh observable to trigger reloading the model list.
useRefreshAnalyticsList({
isLoading: setIsLoading,
onRefresh: fetchData,
});

const modelsStats: ModelsBarStats = useMemo(() => {
return {
Expand All @@ -129,34 +153,27 @@ export const ModelsList: FC = () => {
/**
* Fetches models stats and update the original object
*/
const fetchModelsStats = useCallback(
async (models: ModelItem[]) => {
const modelIdsToFetch = models
.filter((model) => model.stats === undefined)
.map((model) => model.model_id);

// no need to fetch
if (modelIdsToFetch.length === 0) return true;

try {
const {
trained_model_stats: modelsStatsResponse,
} = await inferenceApiService.getInferenceModelStats(modelIdsToFetch);
for (const { model_id: id, ...stats } of modelsStatsResponse) {
const model = models.find((m) => m.model_id === id);
model!.stats = stats;
}
return true;
} catch (error) {
toasts.addError(new Error(error.body.message), {
title: i18n.translate('xpack.ml.inference.modelsList.fetchModelStatsErrorMessage', {
defaultMessage: 'Fetch model stats failed',
}),
});
const fetchModelsStats = useCallback(async (models: ModelItem[]) => {
const modelIdsToFetch = models.map((model) => model.model_id);

try {
const {
trained_model_stats: modelsStatsResponse,
} = await inferenceApiService.getInferenceModelStats(modelIdsToFetch);

for (const { model_id: id, ...stats } of modelsStatsResponse) {
const model = models.find((m) => m.model_id === id);
model!.stats = stats;
}
},
[items]
);
return true;
} catch (error) {
toasts.addError(new Error(error.body.message), {
title: i18n.translate('xpack.ml.inference.modelsList.fetchModelStatsErrorMessage', {
defaultMessage: 'Fetch model stats failed',
}),
});
}
}, []);

/**
* Unique inference types from models
Expand Down

0 comments on commit 5985295

Please sign in to comment.