From d2a1d12fbac537f8a6c3b813ef7122c3fea74685 Mon Sep 17 00:00:00 2001 From: Dmitry Agapov Date: Tue, 6 Apr 2021 13:09:33 +0300 Subject: [PATCH] Integration with an internal training server (#2785) Co-authored-by: Boris Sekachev Co-authored-by: Nikita Manovich --- .vscode/settings.json | 2 +- cvat-core/src/api.js | 2 +- cvat-core/src/project.js | 23 ++ cvat-core/src/server-proxy.js | 123 ++++++ cvat-core/src/session.js | 112 +++++- cvat-ui/package-lock.json | 5 + cvat-ui/package.json | 1 + cvat-ui/src/actions/annotation-actions.ts | 126 +++++- cvat-ui/src/actions/boundaries-actions.ts | 6 +- cvat-ui/src/assets/brain.svg | 56 +++ .../annotation-page/appearance-block.tsx | 13 +- .../attribute-editor.tsx | 3 +- .../controls-side-bar/controls-side-bar.tsx | 2 +- .../controls-side-bar/controls-side-bar.tsx | 2 +- .../components/annotation-page/styles.scss | 46 +++ .../shortcuts-select.tsx | 2 +- .../tag-annotation-sidebar.tsx | 2 +- .../annotation-page/top-bar/left-group.tsx | 6 +- .../annotation-page/top-bar/right-group.tsx | 119 +++++- .../annotation-page/top-bar/top-bar.tsx | 15 +- .../change-password-form.tsx | 8 +- .../change-password-modal.tsx | 6 +- .../create-project-content.tsx | 78 +++- .../create-project-page.tsx | 53 ++- .../create-project.context.ts | 31 ++ .../canvas/canvas-context-menu.tsx | 2 +- .../annotation-page/canvas/canvas-wrapper.tsx | 2 +- .../controls-side-bar/controls-side-bar.tsx | 2 +- .../standard-workspace/propagate-confirm.tsx | 10 +- .../annotation-page/top-bar/top-bar.tsx | 29 +- .../containers/file-manager/file-manager.tsx | 6 +- .../containers/models-page/models-page.tsx | 6 +- .../src/containers/tasks-page/tasks-page.tsx | 8 +- cvat-ui/src/icons.tsx | 2 + cvat-ui/src/reducers/annotation-reducer.ts | 54 ++- cvat-ui/src/reducers/interfaces.ts | 18 + cvat-ui/src/reducers/notifications-reducer.ts | 18 + cvat-ui/src/reducers/plugins-reducer.ts | 3 +- .../engine/migrations/0039_auto_training.py | 48 +++ cvat/apps/engine/models.py | 49 ++- cvat/apps/engine/serializers.py | 31 +- cvat/apps/engine/urls.py | 2 + cvat/apps/engine/views.py | 13 +- cvat/apps/training/__init__.py | 1 + cvat/apps/training/apis.py | 362 ++++++++++++++++++ cvat/apps/training/apps.py | 11 + cvat/apps/training/jobs.py | 186 +++++++++ cvat/apps/training/signals.py | 30 ++ cvat/apps/training/urls.py | 11 + cvat/apps/training/views.py | 68 ++++ cvat/settings/base.py | 5 + cvat/settings/testing.py | 2 +- cvat/urls.py | 3 + docker-compose.yml | 1 + 54 files changed, 1743 insertions(+), 82 deletions(-) create mode 100644 cvat-ui/src/assets/brain.svg create mode 100644 cvat-ui/src/components/create-project-page/create-project.context.ts create mode 100644 cvat/apps/engine/migrations/0039_auto_training.py create mode 100644 cvat/apps/training/__init__.py create mode 100644 cvat/apps/training/apis.py create mode 100644 cvat/apps/training/apps.py create mode 100644 cvat/apps/training/jobs.py create mode 100644 cvat/apps/training/signals.py create mode 100644 cvat/apps/training/urls.py create mode 100644 cvat/apps/training/views.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 5718c4b7c14d..cb78ca045e65 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,5 @@ { - "python.pythonPath": ".env/bin/python", + "eslint.enable": true, "eslint.probe": [ "javascript", "typescript", diff --git a/cvat-core/src/api.js b/cvat-core/src/api.js index 22dc4858e538..a5d36b6ce833 100644 --- a/cvat-core/src/api.js +++ b/cvat-core/src/api.js @@ -573,7 +573,7 @@ function build() { * @param {module:API.cvat.classes.Task} task task to be annotated * @param {module:API.cvat.classes.MLModel} model model used to get annotation * @param {object} [args] extra arguments - * @returns {string} requestID + * @returns {object[]} annotations * @throws {module:API.cvat.exceptions.ServerError} * @throws {module:API.cvat.exceptions.PluginError} * @throws {module:API.cvat.exceptions.ArgumentError} diff --git a/cvat-core/src/project.js b/cvat-core/src/project.js index f389205c6f48..f4609d9d0a00 100644 --- a/cvat-core/src/project.js +++ b/cvat-core/src/project.js @@ -33,6 +33,7 @@ created_date: undefined, updated_date: undefined, task_subsets: undefined, + training_project: undefined, }; for (const property in data) { @@ -64,6 +65,9 @@ } data.task_subsets = Array.from(subsetsSet); } + if (initialData.training_project) { + data.training_project = JSON.parse(JSON.stringify(initialData.training_project)); + } Object.defineProperties( this, @@ -94,6 +98,7 @@ data.name = value; }, }, + /** * @name status * @type {module:API.cvat.enums.TaskStatus} @@ -217,9 +222,21 @@ subsets: { get: () => [...data.task_subsets], }, + _internalData: { get: () => data, }, + + training_project: { + get: () => data.training_project, + set: (training) => { + if (training) { + data.training_project = JSON.parse(JSON.stringify(training)); + } else { + data.training_project = training; + } + }, + }, }), ); } @@ -261,12 +278,17 @@ }; Project.prototype.save.implementation = async function () { + let trainingProject; + if (this.training_project) { + trainingProject = JSON.parse(JSON.stringify(this.training_project)); + } if (typeof this.id !== 'undefined') { const projectData = { name: this.name, assignee_id: this.assignee ? this.assignee.id : null, bug_tracker: this.bugTracker, labels: [...this._internalData.labels.map((el) => el.toJSON())], + training_project: trainingProject, }; await serverProxy.projects.save(this.id, projectData); @@ -276,6 +298,7 @@ const projectSpec = { name: this.name, labels: [...this.labels.map((el) => el.toJSON())], + training_project: trainingProject, }; if (this.bugTracker) { diff --git a/cvat-core/src/server-proxy.js b/cvat-core/src/server-proxy.js index 1e7018c53bc4..f3627d4ec2f8 100644 --- a/cvat-core/src/server-proxy.js +++ b/cvat-core/src/server-proxy.js @@ -9,6 +9,31 @@ const config = require('./config'); const DownloadWorker = require('./download.worker'); + function waitFor(frequencyHz, predicate) { + return new Promise((resolve, reject) => { + if (typeof predicate !== 'function') { + reject(new Error(`Predicate must be a function, got ${typeof predicate}`)); + } + + const internalWait = () => { + let result = false; + try { + result = predicate(); + } catch (error) { + reject(error); + } + + if (result) { + resolve(); + } else { + setTimeout(internalWait, 1000 / frequencyHz); + } + }; + + setTimeout(internalWait); + }); + } + function generateError(errorData) { if (errorData.response) { const message = `${errorData.message}. ${JSON.stringify(errorData.response.data) || ''}.`; @@ -993,6 +1018,96 @@ } } + function predictorStatus(projectId) { + const { backendAPI } = config; + + return new Promise((resolve, reject) => { + async function request() { + try { + const response = await Axios.get(`${backendAPI}/predict/status?project=${projectId}`); + return response.data; + } catch (errorData) { + throw generateError(errorData); + } + } + + const timeoutCallback = async () => { + let data = null; + try { + data = await request(); + if (data.status === 'queued') { + setTimeout(timeoutCallback, 1000); + } else if (data.status === 'done') { + resolve(data); + } else { + throw new Error(`Unknown status was received "${data.status}"`); + } + } catch (error) { + reject(error); + } + }; + + setTimeout(timeoutCallback); + }); + } + + function predictAnnotations(taskId, frame) { + return new Promise((resolve, reject) => { + const { backendAPI } = config; + + async function request() { + try { + const response = await Axios.get( + `${backendAPI}/predict/frame?task=${taskId}&frame=${frame}`, + ); + return response.data; + } catch (errorData) { + throw generateError(errorData); + } + } + + const timeoutCallback = async () => { + let data = null; + try { + data = await request(); + if (data.status === 'queued') { + setTimeout(timeoutCallback, 1000); + } else if (data.status === 'done') { + predictAnnotations.latestRequest.fetching = false; + resolve(data.annotation); + } else { + throw new Error(`Unknown status was received "${data.status}"`); + } + } catch (error) { + predictAnnotations.latestRequest.fetching = false; + reject(error); + } + }; + + const closureId = Date.now(); + predictAnnotations.latestRequest.id = closureId; + const predicate = () => !predictAnnotations.latestRequest.fetching || predictAnnotations.latestRequest.id !== closureId; + if (predictAnnotations.latestRequest.fetching) { + waitFor(5, predicate).then(() => { + if (predictAnnotations.latestRequest.id !== closureId) { + resolve(null); + } else { + predictAnnotations.latestRequest.fetching = true; + setTimeout(timeoutCallback); + } + }); + } else { + predictAnnotations.latestRequest.fetching = true; + setTimeout(timeoutCallback); + } + }); + } + + predictAnnotations.latestRequest = { + fetching: false, + id: null, + }; + async function installedApps() { const { backendAPI } = config; try { @@ -1123,6 +1238,14 @@ }), writable: false, }, + + predictor: { + value: Object.freeze({ + status: predictorStatus, + predict: predictAnnotations, + }), + writable: false, + }, }), ); } diff --git a/cvat-core/src/session.js b/cvat-core/src/session.js index 105dfc9b5b8c..daaee04a21ab 100644 --- a/cvat-core/src/session.js +++ b/cvat-core/src/session.js @@ -10,7 +10,7 @@ const { getFrame, getRanges, getPreview, clear: clearFrames, getContextImage, } = require('./frames'); - const { ArgumentError } = require('./exceptions'); + const { ArgumentError, DataError } = require('./exceptions'); const { TaskStatus } = require('./enums'); const { Label } = require('./labels'); const User = require('./user'); @@ -258,6 +258,19 @@ }, writable: true, }), + predictor: Object.freeze({ + value: { + async status() { + const result = await PluginRegistry.apiWrapper.call(this, prototype.predictor.status); + return result; + }, + async predict(frame) { + const result = await PluginRegistry.apiWrapper.call(this, prototype.predictor.predict, frame); + return result; + }, + }, + writable: true, + }), }); } @@ -665,6 +678,40 @@ * @instance * @async */ + /** + * @typedef {Object} PredictorStatus + * @property {string} message - message for a user to be displayed somewhere + * @property {number} projectScore - model accuracy + * @global + */ + /** + * Namespace is used for an interaction with events + * @namespace predictor + * @memberof Session + */ + /** + * Subscribe to updates of a ML model binded to the project + * @method status + * @memberof Session.predictor + * @throws {module:API.cvat.exceptions.PluginError} + * @throws {module:API.cvat.exceptions.ServerError} + * @returns {PredictorStatus} + * @instance + * @async + */ + /** + * Get predictions from a ML model binded to the project + * @method predict + * @memberof Session.predictor + * @param {number} frame - number of frame to inference + * @throws {module:API.cvat.exceptions.PluginError} + * @throws {module:API.cvat.exceptions.ArgumentError} + * @throws {module:API.cvat.exceptions.ServerError} + * @throws {module:API.cvat.exceptions.DataError} + * @returns {object[] | null} annotations + * @instance + * @async + */ } } @@ -865,6 +912,11 @@ this.logger = { log: Object.getPrototypeOf(this).logger.log.bind(this), }; + + this.predictor = { + status: Object.getPrototypeOf(this).predictor.status.bind(this), + predict: Object.getPrototypeOf(this).predictor.predict.bind(this), + }; } /** @@ -1554,6 +1606,11 @@ this.logger = { log: Object.getPrototypeOf(this).logger.log.bind(this), }; + + this.predictor = { + status: Object.getPrototypeOf(this).predictor.status.bind(this), + predict: Object.getPrototypeOf(this).predictor.predict.bind(this), + }; } /** @@ -1741,6 +1798,11 @@ return rangesData; }; + Job.prototype.frames.preview.implementation = async function () { + const frameData = await getPreview(this.task.id); + return frameData; + }; + // TODO: Check filter for annotations Job.prototype.annotations.get.implementation = async function (frame, allTracks, filters) { if (!Array.isArray(filters)) { @@ -1897,6 +1959,16 @@ return result; }; + Job.prototype.predictor.status.implementation = async function () { + const result = await this.task.predictor.status(); + return result; + }; + + Job.prototype.predictor.predict.implementation = async function (frame) { + const result = await this.task.predictor.predict(frame); + return result; + }; + Task.prototype.close.implementation = function closeTask() { clearFrames(this.id); for (const job of this.jobs) { @@ -2028,11 +2100,6 @@ return result; }; - Job.prototype.frames.preview.implementation = async function () { - const frameData = await getPreview(this.task.id); - return frameData; - }; - Task.prototype.frames.ranges.implementation = async function () { const rangesData = await getRanges(this.id); return rangesData; @@ -2199,6 +2266,39 @@ return result; }; + Task.prototype.predictor.status.implementation = async function () { + if (!Number.isInteger(this.projectId)) { + throw new DataError('The task must belong to a project to use the feature'); + } + + const result = await serverProxy.predictor.status(this.projectId); + return { + message: result.message, + progress: result.progress, + projectScore: result.score, + timeRemaining: result.time_remaining, + mediaAmount: result.media_amount, + annotationAmount: result.annotation_amount, + }; + }; + + Task.prototype.predictor.predict.implementation = async function (frame) { + if (!Number.isInteger(frame) || frame < 0) { + throw new ArgumentError(`Frame must be a positive integer. Got: "${frame}"`); + } + + if (frame >= this.size) { + throw new ArgumentError(`The frame with number ${frame} is out of the task`); + } + + if (!Number.isInteger(this.projectId)) { + throw new DataError('The task must belong to a project to use the feature'); + } + + const result = await serverProxy.predictor.predict(this.id, frame); + return result; + }; + Job.prototype.frames.contextImage.implementation = async function (taskId, frameId) { const result = await getContextImage(taskId, frameId); return result; diff --git a/cvat-ui/package-lock.json b/cvat-ui/package-lock.json index c23090f3cb8c..27c93845c188 100644 --- a/cvat-ui/package-lock.json +++ b/cvat-ui/package-lock.json @@ -28953,6 +28953,11 @@ "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.11.0.tgz", "integrity": "sha512-gbBVYR2p8mnriqAwWx9LbuUrShnAuSCNnuPGyc7GJrMVQtPDAh8iLpv7FRuMPFb56KkaVZIYSz1PrjI9q0QPCw==" }, + "react-moment": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/react-moment/-/react-moment-1.1.1.tgz", + "integrity": "sha512-WjwvxBSnmLMRcU33do0KixDB+9vP3e84eCse+rd+HNklAMNWyRgZTDEQlay/qK6lcXFPRuEIASJTpEt6pyK7Ww==" + }, "react-redux": { "version": "7.2.2", "resolved": "https://registry.npmjs.org/react-redux/-/react-redux-7.2.2.tgz", diff --git a/cvat-ui/package.json b/cvat-ui/package.json index d6bbde46e5cc..caf6ba255fde 100644 --- a/cvat-ui/package.json +++ b/cvat-ui/package.json @@ -77,6 +77,7 @@ "react-color": "^2.19.3", "react-cookie": "^4.0.3", "react-dom": "^16.14.0", + "react-moment": "^1.1.1", "react-redux": "^7.2.2", "react-resizable": "^1.11.1", "@types/react-resizable": "^1.7.2", diff --git a/cvat-ui/src/actions/annotation-actions.ts b/cvat-ui/src/actions/annotation-actions.ts index c44b5bb3b56a..8e5fd184df86 100644 --- a/cvat-ui/src/actions/annotation-actions.ts +++ b/cvat-ui/src/actions/annotation-actions.ts @@ -1,4 +1,4 @@ -// Copyright (C) 2021 Intel Corporation +// Copyright (C) 2020-2021 Intel Corporation // // SPDX-License-Identifier: MIT @@ -190,6 +190,10 @@ export enum AnnotationActionTypes { SWITCH_REQUEST_REVIEW_DIALOG = 'SWITCH_REQUEST_REVIEW_DIALOG', SWITCH_SUBMIT_REVIEW_DIALOG = 'SWITCH_SUBMIT_REVIEW_DIALOG', SET_FORCE_EXIT_ANNOTATION_PAGE_FLAG = 'SET_FORCE_EXIT_ANNOTATION_PAGE_FLAG', + UPDATE_PREDICTOR_STATE = 'UPDATE_PREDICTOR_STATE', + GET_PREDICTIONS = 'GET_PREDICTIONS', + GET_PREDICTIONS_FAILED = 'GET_PREDICTIONS_FAILED', + GET_PREDICTIONS_SUCCESS = 'GET_PREDICTIONS_SUCCESS', HIDE_SHOW_CONTEXT_IMAGE = 'HIDE_SHOW_CONTEXT_IMAGE', GET_CONTEXT_IMAGE = 'GET_CONTEXT_IMAGE', } @@ -612,6 +616,87 @@ export function switchPlay(playing: boolean): AnyAction { }; } +export function getPredictionsAsync(): ThunkAction { + return async (dispatch: ActionCreator): Promise => { + const { + annotations: { + states: currentStates, + zLayer: { cur: curZOrder }, + }, + predictor: { enabled, annotatedFrames }, + } = getStore().getState().annotation; + + const { + filters, frame, showAllInterpolationTracks, jobInstance: job, + } = receiveAnnotationsParameters(); + if (!enabled || currentStates.length || annotatedFrames.includes(frame)) return; + + dispatch({ + type: AnnotationActionTypes.GET_PREDICTIONS, + payload: {}, + }); + + let annotations = []; + try { + annotations = await job.predictor.predict(frame); + // current frame could be changed during a request above, need to fetch it from store again + const { number: currentFrame } = getStore().getState().annotation.player.frame; + if (frame !== currentFrame || annotations === null) { + // another request has already been sent or user went to another frame + // we do not need dispatch predictions success action + return; + } + annotations = annotations.map( + (data: any): any => + new cvat.classes.ObjectState({ + shapeType: data.type, + label: job.task.labels.filter((label: any): boolean => label.id === data.label)[0], + points: data.points, + objectType: ObjectType.SHAPE, + frame, + occluded: false, + source: 'auto', + attributes: {}, + zOrder: curZOrder, + }), + ); + + dispatch({ + type: AnnotationActionTypes.GET_PREDICTIONS_SUCCESS, + payload: { frame }, + }); + } catch (error) { + dispatch({ + type: AnnotationActionTypes.GET_PREDICTIONS_FAILED, + payload: { + error, + }, + }); + } + + try { + await job.annotations.put(annotations); + const states = await job.annotations.get(frame, showAllInterpolationTracks, filters); + const history = await job.actions.get(); + + dispatch({ + type: AnnotationActionTypes.CREATE_ANNOTATIONS_SUCCESS, + payload: { + states, + history, + }, + }); + } catch (error) { + dispatch({ + type: AnnotationActionTypes.CREATE_ANNOTATIONS_FAILED, + payload: { + error, + }, + }); + } + }; +} + export function changeFrameAsync(toFrame: number, fillBuffer?: boolean, frameStep?: number): ThunkAction { return async (dispatch: ActionCreator): Promise => { const state: CombinedState = getStore().getState(); @@ -689,6 +774,7 @@ export function changeFrameAsync(toFrame: number, fillBuffer?: boolean, frameSte delay, }, }); + dispatch(getPredictionsAsync()); } catch (error) { if (error !== 'not needed') { dispatch({ @@ -934,9 +1020,11 @@ export function getJobAsync(tid: number, jid: number, initialFrame: number, init loadJobEvent.close(await jobInfoGenerator(job)); + const openTime = Date.now(); dispatch({ type: AnnotationActionTypes.GET_JOB_SUCCESS, payload: { + openTime, job, issues, reviews, @@ -950,10 +1038,38 @@ export function getJobAsync(tid: number, jid: number, initialFrame: number, init maxZ, }, }); + if (job.task.dimension === DimensionType.DIM_3D) { const workspace = Workspace.STANDARD3D; dispatch(changeWorkspace(workspace)); } + + const updatePredictorStatus = async (): Promise => { + // get current job + const currentState: CombinedState = getStore().getState(); + const { openTime: currentOpenTime, instance: currentJob } = currentState.annotation.job; + if (currentJob === null || currentJob.id !== job.id || currentOpenTime !== openTime) { + // the job was closed, changed or reopened + return; + } + + try { + const status = await job.predictor.status(); + dispatch({ + type: AnnotationActionTypes.UPDATE_PREDICTOR_STATE, + payload: status, + }); + setTimeout(updatePredictorStatus, 60 * 1000); + } catch (error) { + dispatch({ + type: AnnotationActionTypes.UPDATE_PREDICTOR_STATE, + payload: { error }, + }); + setTimeout(updatePredictorStatus, 20 * 1000); + } + }; + updatePredictorStatus(); + dispatch(changeFrameAsync(frameNumber, false)); } catch (error) { dispatch({ @@ -1516,6 +1632,14 @@ export function setForceExitAnnotationFlag(forceExit: boolean): AnyAction { }; } +export function switchPredictor(predictorEnabled: boolean): AnyAction { + return { + type: AnnotationActionTypes.UPDATE_PREDICTOR_STATE, + payload: { + enabled: predictorEnabled, + }, + }; +} export function hideShowContextImage(hidden: boolean): AnyAction { return { type: AnnotationActionTypes.HIDE_SHOW_CONTEXT_IMAGE, diff --git a/cvat-ui/src/actions/boundaries-actions.ts b/cvat-ui/src/actions/boundaries-actions.ts index 0e22f60958a4..5da395d8c1f3 100644 --- a/cvat-ui/src/actions/boundaries-actions.ts +++ b/cvat-ui/src/actions/boundaries-actions.ts @@ -1,8 +1,10 @@ -// Copyright (C) 2020 Intel Corporation +// Copyright (C) 2020-2021 Intel Corporation // // SPDX-License-Identifier: MIT -import { ActionUnion, createAction, ThunkAction, ThunkDispatch } from 'utils/redux'; +import { + ActionUnion, createAction, ThunkAction, ThunkDispatch, +} from 'utils/redux'; import getCore from 'cvat-core-wrapper'; import { LogType } from 'cvat-logger'; import { computeZRange } from './annotation-actions'; diff --git a/cvat-ui/src/assets/brain.svg b/cvat-ui/src/assets/brain.svg new file mode 100644 index 000000000000..4aebe4071219 --- /dev/null +++ b/cvat-ui/src/assets/brain.svg @@ -0,0 +1,56 @@ + + + + + + + + + + + + diff --git a/cvat-ui/src/components/annotation-page/appearance-block.tsx b/cvat-ui/src/components/annotation-page/appearance-block.tsx index 8ea48cf3e786..a00a4937a063 100644 --- a/cvat-ui/src/components/annotation-page/appearance-block.tsx +++ b/cvat-ui/src/components/annotation-page/appearance-block.tsx @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Intel Corporation +// Copyright (C) 2020-2021 Intel Corporation // // SPDX-License-Identifier: MIT @@ -10,6 +10,7 @@ import Radio, { RadioChangeEvent } from 'antd/lib/radio'; import Slider from 'antd/lib/slider'; import Checkbox, { CheckboxChangeEvent } from 'antd/lib/checkbox'; import Collapse from 'antd/lib/collapse'; +import Button from 'antd/lib/button'; import ColorPicker from 'components/annotation-page/standard-workspace/objects-side-bar/color-picker'; import { ColorizeIcon } from 'icons'; @@ -26,7 +27,6 @@ import { changeShowBitmap as changeShowBitmapAction, changeShowProjections as changeShowProjectionsAction, } from 'actions/settings-actions'; -import Button from 'antd/lib/button'; interface StateToProps { appearanceCollapsed: boolean; @@ -152,7 +152,14 @@ function AppearanceBlock(props: Props): JSX.Element { activeKey={appearanceCollapsed ? [] : ['appearance']} className='cvat-objects-appearance-collapse' > - Appearance} key='appearance'> + + Appearance + + )} + key='appearance' + >
Color by span { + > svg { + fill: $inprogress-progress-color; + } + } + } + + &.cvat-predictor-fetching { + > span { + > svg { + animation-duration: 500ms; + animation-name: predictorBlinking; + animation-iteration-count: infinite; + + @keyframes predictorBlinking { + 0% { + fill: $inprogress-progress-color; + } + + 50% { + fill: $completed-progress-color; + } + + 100% { + fill: $inprogress-progress-color; + } + } + } + } + } + + &.cvat-predictor-disabled { + opacity: 0.5; + + &:active { + pointer-events: none; + } + + > span[role='img'] { + transform: scale(0.8) !important; + } + } +} + .cvat-annotation-disabled-header-button { @extend .cvat-annotation-header-button; diff --git a/cvat-ui/src/components/annotation-page/tag-annotation-workspace/tag-annotation-sidebar/shortcuts-select.tsx b/cvat-ui/src/components/annotation-page/tag-annotation-workspace/tag-annotation-sidebar/shortcuts-select.tsx index 1ec428489456..1c495a15344d 100644 --- a/cvat-ui/src/components/annotation-page/tag-annotation-workspace/tag-annotation-sidebar/shortcuts-select.tsx +++ b/cvat-ui/src/components/annotation-page/tag-annotation-workspace/tag-annotation-sidebar/shortcuts-select.tsx @@ -4,12 +4,12 @@ import React, { useState, useEffect } from 'react'; import { useSelector } from 'react-redux'; -import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react'; import { Row, Col } from 'antd/lib/grid'; import Text from 'antd/lib/typography/Text'; import Select from 'antd/lib/select'; import { CombinedState } from 'reducers/interfaces'; +import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react'; import { shift } from 'utils/math'; interface ShortcutLabelMap { diff --git a/cvat-ui/src/components/annotation-page/tag-annotation-workspace/tag-annotation-sidebar/tag-annotation-sidebar.tsx b/cvat-ui/src/components/annotation-page/tag-annotation-workspace/tag-annotation-sidebar/tag-annotation-sidebar.tsx index f54b405e7643..7b5e44563d44 100644 --- a/cvat-ui/src/components/annotation-page/tag-annotation-workspace/tag-annotation-sidebar/tag-annotation-sidebar.tsx +++ b/cvat-ui/src/components/annotation-page/tag-annotation-workspace/tag-annotation-sidebar/tag-annotation-sidebar.tsx @@ -20,11 +20,11 @@ import { changeFrameAsync, rememberObject, } from 'actions/annotation-actions'; -import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react'; import { Canvas } from 'cvat-canvas-wrapper'; import { CombinedState, ObjectType } from 'reducers/interfaces'; import LabelSelector from 'components/label-selector/label-selector'; import getCore from 'cvat-core-wrapper'; +import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react'; import ShortcutsSelect from './shortcuts-select'; const cvat = getCore(); diff --git a/cvat-ui/src/components/annotation-page/top-bar/left-group.tsx b/cvat-ui/src/components/annotation-page/top-bar/left-group.tsx index 49d9e1ff2244..dbbad65fed56 100644 --- a/cvat-ui/src/components/annotation-page/top-bar/left-group.tsx +++ b/cvat-ui/src/components/annotation-page/top-bar/left-group.tsx @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Intel Corporation +// Copyright (C) 2020-2021 Intel Corporation // // SPDX-License-Identifier: MIT @@ -11,7 +11,9 @@ import Timeline from 'antd/lib/timeline'; import Dropdown from 'antd/lib/dropdown'; import AnnotationMenuContainer from 'containers/annotation-page/top-bar/annotation-menu'; -import { MainMenuIcon, SaveIcon, UndoIcon, RedoIcon } from 'icons'; +import { + MainMenuIcon, SaveIcon, UndoIcon, RedoIcon, +} from 'icons'; interface Props { saving: boolean; diff --git a/cvat-ui/src/components/annotation-page/top-bar/right-group.tsx b/cvat-ui/src/components/annotation-page/top-bar/right-group.tsx index 9d863f8c547a..b91e575fd905 100644 --- a/cvat-ui/src/components/annotation-page/top-bar/right-group.tsx +++ b/cvat-ui/src/components/annotation-page/top-bar/right-group.tsx @@ -7,28 +7,141 @@ import { Col } from 'antd/lib/grid'; import Icon from '@ant-design/icons'; import Select from 'antd/lib/select'; import Button from 'antd/lib/button'; +import Text from 'antd/lib/typography/Text'; +import Tooltip from 'antd/lib/tooltip'; +import Moment from 'react-moment'; + +import moment from 'moment'; import { useSelector } from 'react-redux'; -import { FilterIcon, FullscreenIcon, InfoIcon } from 'icons'; -import { CombinedState, DimensionType, Workspace } from 'reducers/interfaces'; +import { + FilterIcon, FullscreenIcon, InfoIcon, BrainIcon, +} from 'icons'; +import { + CombinedState, DimensionType, Workspace, PredictorState, +} from 'reducers/interfaces'; interface Props { workspace: Workspace; + predictor: PredictorState; + isTrainingActive: boolean; + showStatistics(): void; + + switchPredictor(predictorEnabled: boolean): void; + showFilters(): void; changeWorkspace(workspace: Workspace): void; + jobInstance: any; } function RightGroup(props: Props): JSX.Element { const { - showFilters, showStatistics, changeWorkspace, workspace, jobInstance, + showStatistics, + changeWorkspace, + switchPredictor, + workspace, + predictor, + jobInstance, + isTrainingActive, + showFilters, } = props; + predictor.annotationAmount = predictor.annotationAmount ? predictor.annotationAmount : 0; + predictor.mediaAmount = predictor.mediaAmount ? predictor.mediaAmount : 0; + const formattedScore = `${(predictor.projectScore * 100).toFixed(0)}%`; + const predictorTooltip = ( +
+ Adaptive auto annotation is + {predictor.enabled ? ( + + {' active'} + + ) : ( + + {' inactive'} + + )} +
+ + Annotations amount: + {predictor.annotationAmount} + +
+ + Media amount: + {predictor.mediaAmount} + +
+ {predictor.annotationAmount > 0 ? ( + + Model mAP is + {' '} + {formattedScore} +
+
+ ) : null} + {predictor.error ? ( + + {predictor.error.toString()} +
+
+ ) : null} + {predictor.message ? ( + + Status: + {' '} + {predictor.message} +
+
+ ) : null} + {predictor.timeRemaining > 0 ? ( + + Time Remaining: + {' '} + +
+
+ ) : null} + {predictor.progress > 0 ? ( + + Progress: + {predictor.progress.toFixed(1)} + {' '} + % + + ) : null} +
+ ); + + let predictorClassName = 'cvat-annotation-header-button cvat-predictor-button'; + if (!!predictor.error || !predictor.projectScore) { + predictorClassName += ' cvat-predictor-disabled'; + } else if (predictor.enabled) { + if (predictor.fetching) { + predictorClassName += ' cvat-predictor-fetching'; + } + predictorClassName += ' cvat-predictor-inprogress'; + } const filters = useSelector((state: CombinedState) => state.annotation.annotations.filters); return ( + {isTrainingActive && ( + + )} ; @@ -102,7 +162,16 @@ export default function CreateProjectContent(): JSX.Element { if (nameFormRef.current && advancedFormRef.current) { const basicValues = await nameFormRef.current.validateFields(); const advancedValues = await advancedFormRef.current.validateFields(); + const adaptiveAutoAnnotationValues = await adaptiveAutoAnnotationFormRef.current?.validateFields(); projectData.name = basicValues.name; + projectData.training_project = null; + if (adaptiveAutoAnnotationValues) { + projectData.training_project = {}; + for (const [field, value] of Object.entries(adaptiveAutoAnnotationValues)) { + projectData.training_project[field] = value; + } + } + for (const [field, value] of Object.entries(advancedValues)) { projectData[field] = value; } @@ -120,6 +189,11 @@ export default function CreateProjectContent(): JSX.Element { + {isTrainingActive.value && ( + + + + )} Labels: - - Create a new project - - - + + + + Create a new project + + + + ); } + +interface StateToProps { + isTrainingActive: boolean; +} + +function mapStateToProps(state: CombinedState): StateToProps { + return { + isTrainingActive: state.plugins.list.PREDICT, + }; +} + +export default connect(mapStateToProps)(CreateProjectPageComponent); diff --git a/cvat-ui/src/components/create-project-page/create-project.context.ts b/cvat-ui/src/components/create-project-page/create-project.context.ts new file mode 100644 index 000000000000..d283658a3078 --- /dev/null +++ b/cvat-ui/src/components/create-project-page/create-project.context.ts @@ -0,0 +1,31 @@ +// Copyright (C) 2020-2021 Intel Corporation +// +// SPDX-License-Identifier: MIT +import { createContext, Dispatch, SetStateAction } from 'react'; + +export interface IState { + value: T; + set?: Dispatch>; +} + +export function getDefaultState(v: T): IState { + return { + value: v, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + set: (value: SetStateAction): void => {}, + }; +} + +export interface ICreateProjectContext { + projectClass: IState; + trainingEnabled: IState; + isTrainingActive: IState; +} + +export const defaultState: ICreateProjectContext = { + projectClass: getDefaultState(''), + trainingEnabled: getDefaultState(false), + isTrainingActive: getDefaultState(false), +}; + +export default createContext(defaultState); diff --git a/cvat-ui/src/containers/annotation-page/canvas/canvas-context-menu.tsx b/cvat-ui/src/containers/annotation-page/canvas/canvas-context-menu.tsx index b25b297c6cc9..b46020e161b7 100644 --- a/cvat-ui/src/containers/annotation-page/canvas/canvas-context-menu.tsx +++ b/cvat-ui/src/containers/annotation-page/canvas/canvas-context-menu.tsx @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Intel Corporation +// Copyright (C) 2020-2021 Intel Corporation // // SPDX-License-Identifier: MIT diff --git a/cvat-ui/src/containers/annotation-page/canvas/canvas-wrapper.tsx b/cvat-ui/src/containers/annotation-page/canvas/canvas-wrapper.tsx index 4581c4367997..70e94de3f334 100644 --- a/cvat-ui/src/containers/annotation-page/canvas/canvas-wrapper.tsx +++ b/cvat-ui/src/containers/annotation-page/canvas/canvas-wrapper.tsx @@ -3,8 +3,8 @@ // SPDX-License-Identifier: MIT import { connect } from 'react-redux'; - import { KeyMap } from 'utils/mousetrap-react'; + import CanvasWrapperComponent from 'components/annotation-page/canvas/canvas-wrapper'; import { confirmCanvasReady, diff --git a/cvat-ui/src/containers/annotation-page/review-workspace/controls-side-bar/controls-side-bar.tsx b/cvat-ui/src/containers/annotation-page/review-workspace/controls-side-bar/controls-side-bar.tsx index 2125213a2907..44f984423c3e 100644 --- a/cvat-ui/src/containers/annotation-page/review-workspace/controls-side-bar/controls-side-bar.tsx +++ b/cvat-ui/src/containers/annotation-page/review-workspace/controls-side-bar/controls-side-bar.tsx @@ -2,7 +2,6 @@ // // SPDX-License-Identifier: MIT -import { KeyMap } from 'utils/mousetrap-react'; import { connect } from 'react-redux'; import { Canvas } from 'cvat-canvas-wrapper'; @@ -19,6 +18,7 @@ import { } from 'actions/annotation-actions'; import ControlsSideBarComponent from 'components/annotation-page/review-workspace/controls-side-bar/controls-side-bar'; import { ActiveControl, CombinedState, Rotation } from 'reducers/interfaces'; +import { KeyMap } from 'utils/mousetrap-react'; interface StateToProps { canvasInstance: Canvas; diff --git a/cvat-ui/src/containers/annotation-page/standard-workspace/propagate-confirm.tsx b/cvat-ui/src/containers/annotation-page/standard-workspace/propagate-confirm.tsx index 7573855e2873..891dc3054f1f 100644 --- a/cvat-ui/src/containers/annotation-page/standard-workspace/propagate-confirm.tsx +++ b/cvat-ui/src/containers/annotation-page/standard-workspace/propagate-confirm.tsx @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Intel Corporation +// Copyright (C) 2020-2021 Intel Corporation // // SPDX-License-Identifier: MIT @@ -68,7 +68,9 @@ function mapDispatchToProps(dispatch: any): DispatchToProps { type Props = StateToProps & DispatchToProps; class PropagateConfirmContainer extends React.PureComponent { private propagateObject = (): void => { - const { propagateObject, objectState, propagateFrames, frameNumber, stopFrame, jobInstance } = this.props; + const { + propagateObject, objectState, propagateFrames, frameNumber, stopFrame, jobInstance, + } = this.props; const propagateUpToFrame = Math.min(frameNumber + propagateFrames, stopFrame); propagateObject(jobInstance, objectState, frameNumber + 1, propagateUpToFrame); @@ -87,7 +89,9 @@ class PropagateConfirmContainer extends React.PureComponent { }; public render(): JSX.Element { - const { frameNumber, stopFrame, propagateFrames, cancel, objectState } = this.props; + const { + frameNumber, stopFrame, propagateFrames, cancel, objectState, + } = this.props; const propagateUpToFrame = Math.min(frameNumber + propagateFrames, stopFrame); diff --git a/cvat-ui/src/containers/annotation-page/top-bar/top-bar.tsx b/cvat-ui/src/containers/annotation-page/top-bar/top-bar.tsx index ceae69c39ca6..d2fcc3808168 100644 --- a/cvat-ui/src/containers/annotation-page/top-bar/top-bar.tsx +++ b/cvat-ui/src/containers/annotation-page/top-bar/top-bar.tsx @@ -18,6 +18,8 @@ import { searchAnnotationsAsync, searchEmptyFrameAsync, setForceExitAnnotationFlag as setForceExitAnnotationFlagAction, + switchPredictor as switchPredictorAction, + getPredictionsAsync, showFilters as showFiltersAction, showStatistics as showStatisticsAction, switchPlay, @@ -25,7 +27,9 @@ import { } from 'actions/annotation-actions'; import AnnotationTopBarComponent from 'components/annotation-page/top-bar/top-bar'; import { Canvas } from 'cvat-canvas-wrapper'; -import { CombinedState, FrameSpeed, Workspace } from 'reducers/interfaces'; +import { + CombinedState, FrameSpeed, Workspace, PredictorState, +} from 'reducers/interfaces'; import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react'; interface StateToProps { @@ -48,6 +52,8 @@ interface StateToProps { normalizedKeyMap: Record; canvasInstance: Canvas; forceExit: boolean; + predictor: PredictorState; + isTrainingActive: boolean; } interface DispatchToProps { @@ -62,6 +68,7 @@ interface DispatchToProps { searchEmptyFrame(sessionInstance: any, frameFrom: number, frameTo: number): void; setForceExitAnnotationFlag(forceExit: boolean): void; changeWorkspace(workspace: Workspace): void; + switchPredictor(predictorEnabled: boolean): void; } function mapStateToProps(state: CombinedState): StateToProps { @@ -78,12 +85,14 @@ function mapStateToProps(state: CombinedState): StateToProps { job: { instance: jobInstance }, canvas: { ready: canvasIsReady, instance: canvasInstance }, workspace, + predictor, }, settings: { player: { frameSpeed, frameStep }, workspace: { autoSave, autoSaveInterval }, }, shortcuts: { keyMap, normalizedKeyMap }, + plugins: { list }, } = state; return { @@ -106,6 +115,8 @@ function mapStateToProps(state: CombinedState): StateToProps { normalizedKeyMap, canvasInstance, forceExit, + predictor, + isTrainingActive: list.PREDICT, }; } @@ -146,6 +157,12 @@ function mapDispatchToProps(dispatch: any): DispatchToProps { setForceExitAnnotationFlag(forceExit: boolean): void { dispatch(setForceExitAnnotationFlagAction(forceExit)); }, + switchPredictor(predictorEnabled: boolean): void { + dispatch(switchPredictorAction(predictorEnabled)); + if (predictorEnabled) { + dispatch(getPredictionsAsync()); + } + }, }; } @@ -497,11 +514,14 @@ class AnnotationTopBarContainer extends React.PureComponent { redoAction, workspace, canvasIsReady, - searchAnnotations, - changeWorkspace, keyMap, normalizedKeyMap, canvasInstance, + predictor, + searchAnnotations, + changeWorkspace, + switchPredictor, + isTrainingActive, } = this.props; const preventDefault = (event: KeyboardEvent | undefined): void => { @@ -611,6 +631,8 @@ class AnnotationTopBarContainer extends React.PureComponent { onInputChange={this.onChangePlayerInputValue} onURLIconClick={this.onURLIconClick} changeWorkspace={changeWorkspace} + switchPredictor={switchPredictor} + predictor={predictor} workspace={workspace} playing={playing} saving={saving} @@ -636,6 +658,7 @@ class AnnotationTopBarContainer extends React.PureComponent { onUndoClick={this.undo} onRedoClick={this.redo} jobInstance={jobInstance} + isTrainingActive={isTrainingActive} /> ); diff --git a/cvat-ui/src/containers/file-manager/file-manager.tsx b/cvat-ui/src/containers/file-manager/file-manager.tsx index 3db64770e166..7d08a3bd0997 100644 --- a/cvat-ui/src/containers/file-manager/file-manager.tsx +++ b/cvat-ui/src/containers/file-manager/file-manager.tsx @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Intel Corporation +// Copyright (C) 2020-2021 Intel Corporation // // SPDX-License-Identifier: MIT @@ -69,7 +69,9 @@ export class FileManagerContainer extends React.PureComponent { } public render(): JSX.Element { - const { treeData, getTreeData, withRemote, onChangeActiveKey } = this.props; + const { + treeData, getTreeData, withRemote, onChangeActiveKey, + } = this.props; return ( !task.instance.jobs.length).length - : 0, + numberOfHiddenTasks: tasks.hideEmpty ? + tasks.current.filter((task: Task): boolean => !task.instance.jobs.length).length : + 0, }; } diff --git a/cvat-ui/src/icons.tsx b/cvat-ui/src/icons.tsx index fc7d8e5ed976..9c0c1488e39e 100644 --- a/cvat-ui/src/icons.tsx +++ b/cvat-ui/src/icons.tsx @@ -47,6 +47,7 @@ import SVGCubeIcon from './assets/cube-icon.svg'; import SVGResetPerspectiveIcon from './assets/reset-perspective.svg'; import SVGColorizeIcon from './assets/colorize-icon.svg'; import SVGAITools from './assets/ai-tools-icon.svg'; +import SVGBrain from './assets/brain.svg'; import SVGOpenCV from './assets/opencv.svg'; import SVGFilterIcon from './assets/object-filter-icon.svg'; @@ -93,5 +94,6 @@ export const CubeIcon = React.memo((): JSX.Element => ); export const ResetPerspectiveIcon = React.memo((): JSX.Element => ); export const AIToolsIcon = React.memo((): JSX.Element => ); export const ColorizeIcon = React.memo((): JSX.Element => ); +export const BrainIcon = React.memo((): JSX.Element => ); export const OpenCVIcon = React.memo((): JSX.Element => ); export const FilterIcon = React.memo((): JSX.Element => ); diff --git a/cvat-ui/src/reducers/annotation-reducer.ts b/cvat-ui/src/reducers/annotation-reducer.ts index 8314c2877ea1..84682fb262dd 100644 --- a/cvat-ui/src/reducers/annotation-reducer.ts +++ b/cvat-ui/src/reducers/annotation-reducer.ts @@ -1,4 +1,4 @@ -// Copyright (C) 2021 Intel Corporation +// Copyright (C) 2020-2021 Intel Corporation // // SPDX-License-Identifier: MIT @@ -38,6 +38,7 @@ const defaultState: AnnotationState = { activeControl: ActiveControl.CURSOR, }, job: { + openTime: null, labels: [], requestedId: null, instance: null, @@ -108,6 +109,14 @@ const defaultState: AnnotationState = { requestReviewDialogVisible: false, submitReviewDialogVisible: false, tabContentHeight: 0, + predictor: { + enabled: false, + error: null, + message: '', + projectScore: 0, + fetching: false, + annotatedFrames: [], + }, workspace: Workspace.STANDARD, }; @@ -129,6 +138,7 @@ export default (state = defaultState, action: AnyAction): AnnotationState => { const { job, states, + openTime, frameNumber: number, frameFilename: filename, colors, @@ -148,6 +158,7 @@ export default (state = defaultState, action: AnyAction): AnnotationState => { ...state, job: { ...state.job, + openTime, fetching: false, instance: job, labels: job.task.labels, @@ -1093,6 +1104,47 @@ export default (state = defaultState, action: AnyAction): AnnotationState => { workspace, }; } + case AnnotationActionTypes.UPDATE_PREDICTOR_STATE: { + const { payload } = action; + return { + ...state, + predictor: { + ...state.predictor, + ...payload, + }, + }; + } + case AnnotationActionTypes.GET_PREDICTIONS: { + return { + ...state, + predictor: { + ...state.predictor, + fetching: true, + }, + }; + } + case AnnotationActionTypes.GET_PREDICTIONS_SUCCESS: { + const { frame } = action.payload; + const annotatedFrames = [...state.predictor.annotatedFrames, frame]; + + return { + ...state, + predictor: { + ...state.predictor, + fetching: false, + annotatedFrames, + }, + }; + } + case AnnotationActionTypes.GET_PREDICTIONS_FAILED: { + return { + ...state, + predictor: { + ...state.predictor, + fetching: false, + }, + }; + } case AnnotationActionTypes.RESET_CANVAS: { return { ...state, diff --git a/cvat-ui/src/reducers/interfaces.ts b/cvat-ui/src/reducers/interfaces.ts index 45af4585d14f..4defc78006f6 100644 --- a/cvat-ui/src/reducers/interfaces.ts +++ b/cvat-ui/src/reducers/interfaces.ts @@ -111,6 +111,7 @@ export enum SupportedPlugins { GIT_INTEGRATION = 'GIT_INTEGRATION', ANALYTICS = 'ANALYTICS', MODELS = 'MODELS', + PREDICT = 'PREDICT', } export type PluginsList = { @@ -301,6 +302,9 @@ export interface NotificationsState { commentingIssue: null | ErrorState; submittingReview: null | ErrorState; }; + predictor: { + prediction: null | ErrorState; + }; }; messages: { tasks: { @@ -367,6 +371,18 @@ export enum Rotation { CLOCKWISE90, } +export interface PredictorState { + timeRemaining: number; + progress: number; + projectScore: number; + message: string; + error: Error | null; + enabled: boolean; + fetching: boolean; + annotationAmount: number; + mediaAmount: number; +} + export interface AnnotationState { activities: { loads: { @@ -388,6 +404,7 @@ export interface AnnotationState { activeControl: ActiveControl; }; job: { + openTime: null | number; labels: any[]; requestedId: number | null; instance: any | null | undefined; @@ -462,6 +479,7 @@ export interface AnnotationState { appearanceCollapsed: boolean; tabContentHeight: number; workspace: Workspace; + predictor: PredictorState; aiToolsRef: MutableRefObject; } diff --git a/cvat-ui/src/reducers/notifications-reducer.ts b/cvat-ui/src/reducers/notifications-reducer.ts index 4df4f4461852..bc56f3ffdbaa 100644 --- a/cvat-ui/src/reducers/notifications-reducer.ts +++ b/cvat-ui/src/reducers/notifications-reducer.ts @@ -102,6 +102,9 @@ const defaultState: NotificationsState = { resolvingIssue: null, submittingReview: null, }, + predictor: { + prediction: null, + }, }, messages: { tasks: { @@ -1104,6 +1107,21 @@ export default function (state = defaultState, action: AnyAction): Notifications }, }; } + case AnnotationActionTypes.GET_PREDICTIONS_FAILED: { + return { + ...state, + errors: { + ...state.errors, + predictor: { + ...state.errors.predictor, + prediction: { + message: 'Could not fetch prediction data', + reason: action.payload.error, + }, + }, + }, + }; + } case BoundariesActionTypes.RESET_AFTER_ERROR: case AuthActionTypes.LOGOUT_SUCCESS: { return { ...defaultState }; diff --git a/cvat-ui/src/reducers/plugins-reducer.ts b/cvat-ui/src/reducers/plugins-reducer.ts index 85e6093c34dc..ad424238831e 100644 --- a/cvat-ui/src/reducers/plugins-reducer.ts +++ b/cvat-ui/src/reducers/plugins-reducer.ts @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Intel Corporation +// Copyright (C) 2020-2021 Intel Corporation // // SPDX-License-Identifier: MIT @@ -13,6 +13,7 @@ const defaultState: PluginsState = { GIT_INTEGRATION: false, ANALYTICS: false, MODELS: false, + PREDICT: false, }, }; diff --git a/cvat/apps/engine/migrations/0039_auto_training.py b/cvat/apps/engine/migrations/0039_auto_training.py new file mode 100644 index 000000000000..a9f22ea7a03a --- /dev/null +++ b/cvat/apps/engine/migrations/0039_auto_training.py @@ -0,0 +1,48 @@ +# Generated by Django 3.1.7 on 2021-04-02 13:17 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('engine', '0038_manifest'), + ] + + operations = [ + migrations.CreateModel( + name='TrainingProject', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('host', models.CharField(max_length=256)), + ('username', models.CharField(max_length=256)), + ('password', models.CharField(max_length=256)), + ('training_id', models.CharField(max_length=64)), + ('enabled', models.BooleanField(null=True)), + ('project_class', models.CharField(blank=True, choices=[('OD', 'Object Detection')], max_length=2, null=True)), + ], + ), + migrations.CreateModel( + name='TrainingProjectLabel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('training_label_id', models.CharField(max_length=64)), + ('cvat_label', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_project_label', to='engine.label')), + ], + ), + migrations.CreateModel( + name='TrainingProjectImage', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('idx', models.PositiveIntegerField()), + ('training_image_id', models.CharField(max_length=64)), + ('task', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='engine.task')), + ], + ), + migrations.AddField( + model_name='project', + name='training_project', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='engine.trainingproject'), + ), + ] diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index d9fcda7743e8..bcc467386fc9 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -2,15 +2,16 @@ # # SPDX-License-Identifier: MIT -from enum import Enum -import re import os +import re +from enum import Enum -from django.db import models from django.conf import settings - from django.contrib.auth.models import User from django.core.files.storage import FileSystemStorage +from django.db import models +from django.utils.translation import gettext_lazy as _ + class SafeCharField(models.CharField): def get_prep_value(self, value): @@ -19,6 +20,7 @@ def get_prep_value(self, value): return value[:self.max_length] return value + class DimensionType(str, Enum): DIM_3D = '3d' DIM_2D = '2d' @@ -152,6 +154,7 @@ class Video(models.Model): class Meta: default_permissions = () + class Image(models.Model): data = models.ForeignKey(Data, on_delete=models.CASCADE, related_name="images", null=True) path = models.CharField(max_length=1024, default='') @@ -162,17 +165,32 @@ class Image(models.Model): class Meta: default_permissions = () + +class TrainingProject(models.Model): + class ProjectClass(models.TextChoices): + DETECTION = 'OD', _('Object Detection') + + host = models.CharField(max_length=256) + username = models.CharField(max_length=256) + password = models.CharField(max_length=256) + training_id = models.CharField(max_length=64) + enabled = models.BooleanField(null=True) + project_class = models.CharField(max_length=2, choices=ProjectClass.choices, null=True, blank=True) + + class Project(models.Model): + name = SafeCharField(max_length=256) owner = models.ForeignKey(User, null=True, blank=True, - on_delete=models.SET_NULL, related_name="+") - assignee = models.ForeignKey(User, null=True, blank=True, - on_delete=models.SET_NULL, related_name="+") + on_delete=models.SET_NULL, related_name="+") + assignee = models.ForeignKey(User, null=True, blank=True, + on_delete=models.SET_NULL, related_name="+") bug_tracker = models.CharField(max_length=2000, blank=True, default="") created_date = models.DateTimeField(auto_now_add=True) updated_date = models.DateTimeField(auto_now_add=True) status = models.CharField(max_length=32, choices=StatusChoice.choices(), - default=StatusChoice.ANNOTATION) + default=StatusChoice.ANNOTATION) + training_project = models.ForeignKey(TrainingProject, null=True, blank=True, on_delete=models.SET_NULL) def get_project_dirname(self): return os.path.join(settings.PROJECTS_ROOT, str(self.id)) @@ -210,7 +228,7 @@ class Task(models.Model): # Zero means that there are no limits (default) segment_size = models.PositiveIntegerField(default=0) status = models.CharField(max_length=32, choices=StatusChoice.choices(), - default=StatusChoice.ANNOTATION) + default=StatusChoice.ANNOTATION) data = models.ForeignKey(Data, on_delete=models.CASCADE, null=True, related_name="tasks") dimension = models.CharField(max_length=2, choices=DimensionType.choices(), default=DimensionType.DIM_2D) subset = models.CharField(max_length=64, blank=True, default="") @@ -237,6 +255,13 @@ def get_task_artifacts_dirname(self): def __str__(self): return self.name + +class TrainingProjectImage(models.Model): + task = models.ForeignKey(Task, on_delete=models.CASCADE) + idx = models.PositiveIntegerField() + training_image_id = models.CharField(max_length=64) + + # Redefined a couple of operation for FileSystemStorage to avoid renaming # or other side effects. class MyFileSystemStorage(FileSystemStorage): @@ -319,6 +344,12 @@ class Meta: default_permissions = () unique_together = ('task', 'name') + +class TrainingProjectLabel(models.Model): + cvat_label = models.ForeignKey(Label, on_delete=models.CASCADE, related_name='training_project_label') + training_label_id = models.CharField(max_length=64) + + class AttributeType(str, Enum): CHECKBOX = 'checkbox' RADIO = 'radio' diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index f4c7b66a1109..dfbe4fa62807 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -9,9 +9,11 @@ from rest_framework import serializers, exceptions from django.contrib.auth.models import User, Group + +from cvat.apps.dataset_manager.formats.utils import get_label_color from cvat.apps.engine import models from cvat.apps.engine.log import slogger -from cvat.apps.dataset_manager.formats.utils import get_label_color + class BasicUserSerializer(serializers.ModelSerializer): def validate(self, data): @@ -415,6 +417,7 @@ def validate_labels(self, value): raise serializers.ValidationError('All label names must be unique for the task') return value + class ProjectSearchSerializer(serializers.ModelSerializer): class Meta: model = models.Project @@ -423,17 +426,25 @@ class Meta: ordering = ['-id'] +class TrainingProjectSerializer(serializers.ModelSerializer): + class Meta: + model = models.TrainingProject + fields = ('host', 'username', 'password', 'enabled', 'project_class') + write_once_fields = ('host', 'username', 'password', 'project_class') + + class ProjectWithoutTaskSerializer(serializers.ModelSerializer): labels = LabelSerializer(many=True, source='label_set', partial=True, default=[]) owner = BasicUserSerializer(required=False) owner_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) assignee = BasicUserSerializer(allow_null=True, required=False) assignee_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) + training_project = TrainingProjectSerializer(required=False, allow_null=True) class Meta: model = models.Project - fields = ('url', 'id', 'name', 'labels', 'owner', 'assignee', 'owner_id', 'assignee_id', - 'bug_tracker', 'created_date', 'updated_date', 'status') + fields = ('url', 'id', 'name', 'labels', 'tasks', 'owner', 'assignee', 'owner_id', 'assignee_id', + 'bug_tracker', 'created_date', 'updated_date', 'status', 'training_project') read_only_fields = ('created_date', 'updated_date', 'status', 'owner', 'asignee') ordering = ['-id'] @@ -456,7 +467,17 @@ class Meta(ProjectWithoutTaskSerializer.Meta): # pylint: disable=no-self-use def create(self, validated_data): labels = validated_data.pop('label_set') - db_project = models.Project.objects.create(**validated_data) + training_data = validated_data.pop('training_project', {}) + if training_data.get('enabled'): + host = training_data.pop('host').strip('/') + username = training_data.pop('username').strip() + password = training_data.pop('password').strip() + tr_p = models.TrainingProject.objects.create(**training_data, + host=host, username=username, password=password) + db_project = models.Project.objects.create(**validated_data, + training_project=tr_p) + else: + db_project = models.Project.objects.create(**validated_data) label_names = list() for label in labels: attributes = label.pop('attributespec_set') @@ -472,7 +493,6 @@ def create(self, validated_data): shutil.rmtree(project_path) os.makedirs(db_project.get_project_logs_dirname()) - db_project.save() return db_project # pylint: disable=no-self-use @@ -530,6 +550,7 @@ class PluginsSerializer(serializers.Serializer): GIT_INTEGRATION = serializers.BooleanField() ANALYTICS = serializers.BooleanField() MODELS = serializers.BooleanField() + PREDICT = serializers.BooleanField() class DataMetaSerializer(serializers.ModelSerializer): frames = FrameMetaSerializer(many=True, allow_null=True) diff --git a/cvat/apps/engine/urls.py b/cvat/apps/engine/urls.py index da0c1f2e6bbe..abc9110811a9 100644 --- a/cvat/apps/engine/urls.py +++ b/cvat/apps/engine/urls.py @@ -13,6 +13,7 @@ from django.conf import settings from cvat.apps.restrictions.views import RestrictionsViewSet from cvat.apps.authentication.decorators import login_required +from cvat.apps.training.views import PredictView schema_view = get_schema_view( openapi.Info( @@ -53,6 +54,7 @@ def _map_format_to_schema(request, scheme=None): router.register('issues', views.IssueViewSet) router.register('comments', views.CommentViewSet) router.register('restrictions', RestrictionsViewSet, basename='restrictions') +router.register('predict', PredictView, basename='predict') urlpatterns = [ # Entry point for a client diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index d15136b69d3c..89751952066f 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -2,23 +2,23 @@ # # SPDX-License-Identifier: MIT +import io import os import os.path as osp -import io import shutil import traceback from datetime import datetime from distutils.util import strtobool from tempfile import mkstemp -import cv2 +import cv2 import django_rq -from django.shortcuts import get_object_or_404 from django.apps import apps from django.conf import settings from django.contrib.auth.models import User from django.db import IntegrityError from django.http import HttpResponse +from django.shortcuts import get_object_or_404 from django.utils import timezone from django.utils.decorators import method_decorator from django_filters import rest_framework as filters @@ -35,7 +35,7 @@ from sendfile import sendfile import cvat.apps.dataset_manager as dm -import cvat.apps.dataset_manager.views # pylint: disable=unused-import +import cvat.apps.dataset_manager.views # pylint: disable=unused-import from cvat.apps.authentication import auth from cvat.apps.dataset_manager.bindings import CvatImportError from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer @@ -53,7 +53,6 @@ CombinedReviewSerializer, IssueSerializer, CombinedIssueSerializer, CommentSerializer ) from cvat.apps.engine.utils import av_scan_paths - from . import models, task from .log import clogger, slogger @@ -188,6 +187,7 @@ def plugins(request): 'GIT_INTEGRATION': apps.is_installed('cvat.apps.dataset_repo'), 'ANALYTICS': False, 'MODELS': False, + 'PREDICT': apps.is_installed('cvat.apps.training') } if strtobool(os.environ.get("CVAT_ANALYTICS", '0')): response['ANALYTICS'] = True @@ -290,6 +290,7 @@ def tasks(self, request, pk): context={"request": request}) return Response(serializer.data) + class TaskFilter(filters.FilterSet): project = filters.CharFilter(field_name="project__name", lookup_expr="icontains") name = filters.CharFilter(field_name="name", lookup_expr="icontains") @@ -1109,3 +1110,5 @@ def _export_annotations(db_task, rq_id, request, format_name, action, callback, meta={ 'request_time': timezone.localtime() }, result_ttl=ttl, failure_ttl=ttl) return Response(status=status.HTTP_202_ACCEPTED) + + diff --git a/cvat/apps/training/__init__.py b/cvat/apps/training/__init__.py new file mode 100644 index 000000000000..2bb1b0c81b7f --- /dev/null +++ b/cvat/apps/training/__init__.py @@ -0,0 +1 @@ +default_app_config = 'cvat.apps.training.apps.TrainingConfig' diff --git a/cvat/apps/training/apis.py b/cvat/apps/training/apis.py new file mode 100644 index 000000000000..d280f34a69f8 --- /dev/null +++ b/cvat/apps/training/apis.py @@ -0,0 +1,362 @@ +import uuid +from abc import ABC, abstractmethod +from collections import OrderedDict +from functools import wraps +from typing import Callable, List, Union + +import requests + +from cacheops import cache, CacheMiss + +from cvat.apps.engine.models import TrainingProject, ShapeType + + +class TrainingServerAPIAbs(ABC): + + def __init__(self, host, username, password): + self.host = host + self.username = username + self.password = password + + @abstractmethod + def get_server_status(self): + pass + + @abstractmethod + def create_project(self, name: str, description: str = '', project_class: TrainingProject.ProjectClass = None, + labels: List[dict] = None): + pass + + @abstractmethod + def upload_annotations(self, project_id: str, frames_data: List[dict]): + pass + + @abstractmethod + def get_project_status(self, project_id: str) -> dict: + pass + + @abstractmethod + def get_annotation(self, project_id: str, image_id: str, width: int, height: int, frame: int, + labels_mapping: dict) -> dict: + pass + + +def retry(amount: int = 2) -> Callable: + def dec(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + __amount = amount + while __amount > 0: + __amount -= 1 + try: + result = func(*args, **kwargs) + return result + except Exception: + pass + + return wrapper + + return dec + + +class TrainingServerAPI(TrainingServerAPIAbs): + TRAINING_CLASS = { + TrainingProject.ProjectClass.DETECTION: "DETECTION" + } + + @staticmethod + def __convert_annotation_from_cvat(shapes): + data = [] + for shape in shapes: + x0, y0, x1, y1 = shape['points'] + x = x0 / shape['width'] + y = y0 / shape['height'] + width = (x1 - x0) / shape['width'] + height = (y1 - y0) / shape['height'] + data.append({ + "id": str(uuid.uuid4()), + "shapes": [ + { + "type": "rect", + "geometry": { + "x": x, + "y": y, + "width": width, + "height": height, + "points": None, + } + } + ], + "editor": None, + "labels": [ + { + "id": shape['third_party_label_id'], + "probability": 1.0, + }, + ], + }) + return data + + @staticmethod + def __convert_annotation_to_cvat(annotation: dict, image_width: int, image_height: int, frame: int, + labels_mapping: dict) -> List[OrderedDict]: + shapes = [] + for i, annotation in enumerate(annotation.get('data', [])): + label_id = annotation['labels'][0]['id'] + if not labels_mapping.get(label_id): + continue + shape = annotation['shapes'][0] + if shape['type'] != 'rect': + continue + x = shape['geometry']['x'] + y = shape['geometry']['y'] + w = shape['geometry']['width'] + h = shape['geometry']['height'] + x0 = x * image_width + y0 = y * image_height + x1 = image_width * w + x0 + y1 = image_height * h + y0 + shapes.append(OrderedDict([ + ('type', ShapeType.RECTANGLE), + ('occluded', False), + ('z_order', 0), + ('points', [x0, y0, x1, y1]), + ('id', i), + ('frame', int(frame)), + ('label', labels_mapping.get(label_id)), + ('group', 0), + ('source', 'auto'), + ('attributes', {}) + ])) + return shapes + + @retry() + def __create_project(self, name: str, description: str = None, + labels: List[dict] = None, tasks: List[dict] = None) -> dict: + url = f'{self.host}/v2/projects' + headers = { + 'Context-Type': 'application/json', + 'Authorization': f'bearer_token {self.token}', + } + tasks[1]['properties'] = [ + { + "id": "labels", + "user_value": labels + } + ] + data = { + 'name': name, + 'description': description, + "dimensions": [], + "group_type": "normal", + 'pipeline': { + 'connections': [{ + 'from': { + **tasks[0]['output_ports'][0], + 'task_id': tasks[0]['temp_id'], + }, + 'to': { + **tasks[1]['input_ports'][0], + 'task_id': tasks[1]['temp_id'], + } + }], + 'tasks': tasks, + }, + "pipeline_representation": 'Detection', + "type": "project", + } + response = self.request(method='POST', url=url, json=data, headers=headers) + return response + + @retry() + def __get_annotation(self, project_id: str, image_id: str) -> dict: + url = f'{self.host}/v2/projects/{project_id}/media/images/{image_id}/results/online' + headers = { + 'Authorization': f'bearer_token {self.token}', + } + response = self.request(method='GET', url=url, headers=headers) + return response + + @retry() + def __get_job_status(self, project_id: str) -> dict: + url = f'{self.host}/v2/projects/{project_id}/jobs' + headers = { + 'Authorization': f'bearer_token {self.token}', + } + response = self.request(method='GET', url=url, headers=headers) + return response + + @retry() + def __get_project_summary(self, project_id: str) -> dict: + url = f'{self.host}/v2/projects/{project_id}/statistics/summary' + headers = { + 'Authorization': f'bearer_token {self.token}', + } + response = self.request(method='GET', url=url, headers=headers) + return response + + @retry() + def __get_project(self, project_id: str) -> dict: + url = f'{self.host}/v2/projects/{project_id}' + headers = { + 'Authorization': f'bearer_token {self.token}', + } + response = self.request(method='GET', url=url, headers=headers) + return response + + @retry() + def __get_server_status(self) -> dict: + url = f'{self.host}/v2/status' + headers = { + 'Authorization': f'bearer_token {self.token}', + } + response = self.request(method='GET', url=url, headers=headers) + return response + + @retry() + def __get_tasks(self) -> List[dict]: + url = f'{self.host}/v2/tasks' + headers = { + 'Authorization': f'bearer_token {self.token}', + } + response = self.request(method='GET', url=url, headers=headers) + return response + + def __delete_token(self): + cache.delete(self.token_key) + + @retry() + def __upload_annotation(self, project_id: str, image_id: str, annotation: List[dict]): + url = f'{self.host}/v2/projects/{project_id}/media/images/{image_id}/annotations' + headers = { + 'Authorization': f'bearer_token {self.token}', + 'Content-Type': 'application/json' + } + data = { + 'image_id': image_id, + 'data': annotation + } + response = self.request(method='POST', url=url, headers=headers, json=data) + return response + + @retry() + def __upload_image(self, project_id: str, buffer) -> dict: + url = f'{self.host}/v2/projects/{project_id}/media/images' + files = {'file': buffer} + headers = { + 'Authorization': f'bearer_token {self.token}', + } + response = self.request(method='POST', url=url, headers=headers, files=files) + return response + + @property + def project_id_key(self): + return f'{self.host}_{self.username}_project_id' + + @property + def token(self) -> str: + def get_token(host: str, username: str, password: str) -> dict: + url = f'{host}/v2/authentication' + data = { + 'username': (None, username), + 'password': (None, password), + } + r = requests.post(url=url, files=data, verify=False) # nosec + return r.json() + + try: + token = cache.get(self.token_key) + except CacheMiss: + response = get_token(self.host, self.username, self.password) + token = response.get('secure_token', '') + expires_in = response.get('expires_in', 3600) + cache.set(cache_key=self.token_key, data=token, timeout=expires_in) + return token + + @property + def token_key(self): + return f'{self.host}_{self.username}_token' + + def request(self, method: str, url: str, **kwargs) -> Union[list, dict, str]: + response = requests.request(method=method, url=url, verify=False, **kwargs) + if response.status_code == 401: + self.__delete_token() + raise Exception("401") + result = response.json() + return result + + def create_project(self, name: str, description: str = '', project_class: TrainingProject.ProjectClass = None, + labels: List[dict] = None) -> dict: + all_tasks = self.__get_tasks() + task_type = self.TRAINING_CLASS.get(project_class) + task_algo = 'Retinanet - TF2' + tasks = [ + next(({'temp_id': '_1_', **task} + for task in all_tasks + if task['task_type'] == 'DATASET'), {}), + next(({'temp_id': '_2_', **task} + for task in all_tasks + if task['task_type'] == task_type and + task['algorithm_name'] == task_algo), {}), + ] + labels = [{ + 'name': label['name'], + 'temp_id': label['name'] + } for label in labels] + r = self.__create_project(name=name, description=description, tasks=tasks, labels=labels) + return r + + def get_server_status(self) -> dict: + return self.__get_server_status() + + def upload_annotations(self, project_id: str, frames_data: List[dict]): + for frame in frames_data: + annotation = self.__convert_annotation_from_cvat(frame['shapes']) + self.__upload_annotation(project_id=project_id, image_id=frame['third_party_id'], annotation=annotation) + + def upload_image(self, training_id: str, buffer): + response = self.__upload_image(project_id=training_id, buffer=buffer) + return response.get('id') + + def get_project_status(self, project_id) -> dict: + summary = self.__get_project_summary(project_id=project_id) + if not summary or not isinstance(summary, list): + return {'message': 'Not available'} + jobs = self.__get_job_status(project_id=project_id) + media_amount = next(item.get('value', 0) for item in summary if item.get('key') == 'Media') + annotation_amount = next(item.get('value', 0) for item in summary if item.get('key') == 'Annotation') + score = next(item.get('value', 0) for item in summary if item.get('key') == 'Score') + job_items = jobs.get('items', 0) + if len(job_items) == 0 and score == 0: + message = 'Not started' + elif len(job_items) == 0 and score > 0: + message = '' + else: + message = 'In progress' + progress = 0 if len(job_items) == 0 else job_items[0]["status"]["progress"] + time_remaining = 0 if len(job_items) == 0 else job_items[0]["status"]['time_remaining'] + result = { + 'media_amount': media_amount if media_amount else 0, + 'annotation_amount': annotation_amount, + 'score': score, + 'message': message, + 'progress': progress, + 'time_remaining': time_remaining, + } + return result + + def get_annotation(self, project_id: str, image_id: str, width: int, height: int, frame: int, + labels_mapping: dict) -> List[OrderedDict]: + annotation = self.__get_annotation(project_id=project_id, image_id=image_id) + cvat_annotation = self.__convert_annotation_to_cvat(annotation=annotation, image_width=width, + image_height=height, frame=frame, + labels_mapping=labels_mapping) + return cvat_annotation + + def get_labels(self, project_id: str) -> List[dict]: + project = self.__get_project(project_id=project_id) + labels = [{ + 'id': label['id'], + 'name': label['name'] + } for label in project.get('labels')] + return labels diff --git a/cvat/apps/training/apps.py b/cvat/apps/training/apps.py new file mode 100644 index 000000000000..a9ea6f3336e9 --- /dev/null +++ b/cvat/apps/training/apps.py @@ -0,0 +1,11 @@ +from django.apps import AppConfig + + +class TrainingConfig(AppConfig): + name = 'cvat.apps.training' + + def ready(self): + # Required to define signals in application + import cvat.apps.training.signals + # Required in order to silent "unused-import" in pyflake + assert cvat.apps.training.signals diff --git a/cvat/apps/training/jobs.py b/cvat/apps/training/jobs.py new file mode 100644 index 000000000000..3cb50fb55c97 --- /dev/null +++ b/cvat/apps/training/jobs.py @@ -0,0 +1,186 @@ +from collections import OrderedDict +from typing import List + +from cacheops import cache +from django_rq import job + +from cvat.apps import dataset_manager as dm +from cvat.apps.engine.frame_provider import FrameProvider +from cvat.apps.engine.models import ( + Project, + Task, + TrainingProjectImage, + Label, + Image, + TrainingProjectLabel, + Data, + Job, + ShapeType, +) +from cvat.apps.training.apis import TrainingServerAPI + + +@job +def save_prediction_server_status_to_cache_job(cache_key, + cvat_project_id, + timeout=60): + cvat_project = Project.objects.get(pk=cvat_project_id) + api = TrainingServerAPI(host=cvat_project.training_project.host, username=cvat_project.training_project.username, + password=cvat_project.training_project.password) + status = api.get_project_status(project_id=cvat_project.training_project.training_id) + + resp = { + **status, + 'status': 'done' + } + cache.set(cache_key=cache_key, data=resp, timeout=timeout) + + +@job +def save_frame_prediction_to_cache_job(cache_key: str, + task_id: int, + frame: int, + timeout: int = 60): + task = Task.objects.get(pk=task_id) + training_project_image = TrainingProjectImage.objects.filter(idx=frame, task=task).first() + if not training_project_image: + cache.set(cache_key=cache_key, data={ + 'annotation': [], + 'status': 'done' + }, timeout=timeout) + return + + cvat_labels = Label.objects.filter(project__id=task.project_id).all() + training_project = Project.objects.get(pk=task.project_id).training_project + api = TrainingServerAPI(host=training_project.host, + username=training_project.username, + password=training_project.password) + image = Image.objects.get(frame=frame, data=task.data) + labels_mapping = { + TrainingProjectLabel.objects.get(cvat_label=cvat_label).training_label_id: cvat_label.id + for cvat_label in cvat_labels + } + annotation = api.get_annotation(project_id=training_project.training_id, + image_id=training_project_image.training_image_id, + width=image.width, + height=image.height, + labels_mapping=labels_mapping, + frame=frame) + resp = { + 'annotation': annotation, + 'status': 'done' + } + cache.set(cache_key=cache_key, data=resp, timeout=timeout) + + +@job +def upload_images_job(task_id: int): + if TrainingProjectImage.objects.filter(task_id=task_id).count() is 0: + task = Task.objects.get(pk=task_id) + frame_provider = FrameProvider(task.data) + frames = frame_provider.get_frames() + api = TrainingServerAPI( + host=task.project.training_project.host, + username=task.project.training_project.username, + password=task.project.training_project.password, + ) + + for i, (buffer, _) in enumerate(frames): + training_image_id = api.upload_image(training_id=task.project.training_project.training_id, buffer=buffer) + if training_image_id: + TrainingProjectImage.objects.create(task=task, idx=i, + training_image_id=training_image_id) + +def __add_fields_to_shape(shape: dict, frame: int, data: Data, labels_mapping: dict) -> dict: + image = Image.objects.get(frame=frame, data=data) + return { + **shape, + 'height': image.height, + 'width': image.width, + 'third_party_label_id': labels_mapping[shape['label_id']], + } + + +@job +def upload_annotation_to_training_project_job(job_id: int): + cvat_job = Job.objects.get(pk=job_id) + cvat_project = cvat_job.segment.task.project + training_project = cvat_project.training_project + start = cvat_job.segment.start_frame + stop = cvat_job.segment.stop_frame + data = dm.task.get_job_data(job_id) + shapes: List[OrderedDict] = data.get('shapes', []) + frames_data = [] + api = TrainingServerAPI( + host=cvat_project.training_project.host, + username=cvat_project.training_project.username, + password=cvat_project.training_project.password, + ) + cvat_labels = Label.objects.filter(project=cvat_project).all() + labels_mapping = { + cvat_label.id: TrainingProjectLabel.objects.get(cvat_label=cvat_label).training_label_id + for cvat_label in cvat_labels + } + + for frame in range(start, stop + 1): + frame_shapes = list( + map( + lambda x: __add_fields_to_shape(x, frame, cvat_job.segment.task.data, labels_mapping), + filter( + lambda x: x['frame'] == frame and x['type'] == ShapeType.RECTANGLE, + shapes, + ) + ) + ) + + if frame_shapes: + training_project_image = TrainingProjectImage.objects.get(task=cvat_job.segment.task, idx=frame) + frames_data.append({ + 'third_party_id': training_project_image.training_image_id, + 'shapes': frame_shapes + }) + + api.upload_annotations(project_id=training_project.training_id, frames_data=frames_data) + + +@job +def create_training_project_job(project_id: int): + cvat_project = Project.objects.get(pk=project_id) + training_project = cvat_project.training_project + api = TrainingServerAPI( + host=cvat_project.training_project.host, + username=cvat_project.training_project.username, + password=cvat_project.training_project.password, + ) + create_training_project(cvat_project=cvat_project, training_project=training_project, api=api) + + +def create_training_project(cvat_project, training_project, api): + labels = cvat_project.label_set.all() + training_project_resp = api.create_project( + name=f'{cvat_project.name}_cvat', + project_class=training_project.project_class, + labels=[{'name': label.name} for label in labels] + ) + if training_project_resp.get('id'): + training_project.training_id = training_project_resp['id'] + training_project.save() + + for cvat_label in labels: + training_label = list(filter(lambda x: x['name'] == cvat_label.name, training_project_resp.get('labels', []))) + if training_label: + TrainingProjectLabel.objects.create(cvat_label=cvat_label, training_label_id=training_label[0]['id']) + + +async def upload_images(cvat_project_id, training_id, api): + project = Project.objects.get(pk=cvat_project_id) + tasks: List[Task] = project.tasks.all() + for task in tasks: + frame_provider = FrameProvider(task) + frames = frame_provider.get_frames() + for i, (buffer, _) in enumerate(frames): + training_image_id = api.upload_image(training_id=training_id, buffer=buffer) + if training_image_id: + TrainingProjectImage.objects.create(project=project, task=task, idx=i, + training_image_id=training_image_id) + diff --git a/cvat/apps/training/signals.py b/cvat/apps/training/signals.py new file mode 100644 index 000000000000..20ba82420377 --- /dev/null +++ b/cvat/apps/training/signals.py @@ -0,0 +1,30 @@ +from django.db.models.signals import post_save +from django.dispatch import receiver + +from cvat.apps.engine.models import Job, StatusChoice, Project, Task +from cvat.apps.training.jobs import ( + create_training_project_job, + upload_images_job, + upload_annotation_to_training_project_job, +) + + +@receiver(post_save, sender=Project, dispatch_uid="create_training_project") +def create_training_project(instance: Project, **kwargs): + if instance.training_project: + create_training_project_job.delay(instance.id) + + +@receiver(post_save, sender=Task, dispatch_uid='upload_images_to_training_project') +def upload_images_to_training_project(instance: Task, **kwargs): + if (instance.status == StatusChoice.ANNOTATION and + instance.data and instance.data.size != 0 and \ + instance.project_id and instance.project.training_project): + + upload_images_job.delay(instance.id) + + +@receiver(post_save, sender=Job, dispatch_uid="upload_annotation_to_training_project") +def upload_annotation_to_training_project(instance: Job, **kwargs): + if instance.status == StatusChoice.COMPLETED: + upload_annotation_to_training_project_job.delay(instance.id) diff --git a/cvat/apps/training/urls.py b/cvat/apps/training/urls.py new file mode 100644 index 000000000000..47ce86bfb0ef --- /dev/null +++ b/cvat/apps/training/urls.py @@ -0,0 +1,11 @@ +from django.urls import path, include +from rest_framework import routers + +from cvat.apps.training.views import PredictView + +router = routers.DefaultRouter(trailing_slash=False) +router.register('', PredictView, basename='predict') + +urlpatterns = [ + path('', include((router.urls, 'predict'), namespace='predict')) +] diff --git a/cvat/apps/training/views.py b/cvat/apps/training/views.py new file mode 100644 index 000000000000..f6fe5dff979e --- /dev/null +++ b/cvat/apps/training/views.py @@ -0,0 +1,68 @@ +from cacheops import cache, CacheMiss +from drf_yasg.utils import swagger_auto_schema +from rest_framework import viewsets, status +from rest_framework.decorators import action +from rest_framework.permissions import IsAuthenticated, SAFE_METHODS +from rest_framework.response import Response + +from cvat.apps.authentication import auth +from cvat.apps.engine.models import Project +from cvat.apps.training.jobs import save_frame_prediction_to_cache_job, save_prediction_server_status_to_cache_job + + +class PredictView(viewsets.ViewSet): + def get_permissions(self): + http_method = self.request.method + permissions = [IsAuthenticated] + + if http_method in SAFE_METHODS: + permissions.append(auth.ProjectAccessPermission) + else: + permissions.append(auth.AdminRolePermission) + + return [perm() for perm in permissions] + + @swagger_auto_schema(method='get', operation_summary='Returns prediction for image') + @action(detail=False, methods=['GET'], url_path='frame') + def predict_image(self, request): + frame = self.request.query_params.get('frame') + task_id = self.request.query_params.get('task') + if not task_id: + return Response(data='query param "task" empty or not provided', status=status.HTTP_400_BAD_REQUEST) + if not frame: + return Response(data='query param "frame" empty or not provided', status=status.HTTP_400_BAD_REQUEST) + cache_key = f'predict_image_{task_id}_{frame}' + try: + resp = cache.get(cache_key) + except CacheMiss: + save_frame_prediction_to_cache_job.delay(cache_key, task_id=task_id, + frame=frame) + resp = { + 'status': 'queued', + } + cache.set(cache_key=cache_key, data=resp, timeout=60) + + return Response(resp) + + @swagger_auto_schema(method='get', + operation_summary='Returns information of the tasks of the project with the selected id') + @action(detail=False, methods=['GET'], url_path='status') + def predict_status(self, request): + project_id = self.request.query_params.get('project') + if not project_id: + return Response(data='query param "project" empty or not provided', status=status.HTTP_400_BAD_REQUEST) + project = Project.objects.get(pk=project_id) + if not project.training_project: + Response({'status': 'done'}) + + cache_key = f'predict_status_{project_id}' + try: + resp = cache.get(cache_key) + except CacheMiss: + save_prediction_server_status_to_cache_job.delay(cache_key, cvat_project_id=project_id) + resp = { + 'status': 'queued', + } + cache.set(cache_key=cache_key, data=resp, timeout=60) + + return Response(resp) diff --git a/cvat/settings/base.py b/cvat/settings/base.py index f7a694415614..bfcd8d652b86 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -20,6 +20,8 @@ import shutil import subprocess import mimetypes +from distutils.util import strtobool + mimetypes.add_type("application/wasm", ".wasm", True) from pathlib import Path @@ -129,6 +131,9 @@ def add_ssh_keys(): 'rest_auth.registration' ] +if strtobool(os.environ.get("ADAPTIVE_AUTO_ANNOTATION", 'false')): + INSTALLED_APPS.append('cvat.apps.training') + SITE_ID = 1 REST_FRAMEWORK = { diff --git a/cvat/settings/testing.py b/cvat/settings/testing.py index c55e6f421126..e79b0f395df2 100644 --- a/cvat/settings/testing.py +++ b/cvat/settings/testing.py @@ -64,4 +64,4 @@ def __init__(self, *args, **kwargs): for config in RQ_QUEUES.values(): config["ASYNC"] = False - super().__init__(*args, **kwargs) \ No newline at end of file + super().__init__(*args, **kwargs) diff --git a/cvat/urls.py b/cvat/urls.py index 0e25dca54793..1fa4cb50978c 100644 --- a/cvat/urls.py +++ b/cvat/urls.py @@ -43,3 +43,6 @@ if apps.is_installed('silk'): urlpatterns.append(path('profiler/', include('silk.urls'))) + +if apps.is_installed('cvat.apps.training'): + urlpatterns.append(path('api/v1/predict/', include('cvat.apps.training.urls'))) diff --git a/docker-compose.yml b/docker-compose.yml index 8eee1393d490..ce09604465ba 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,6 +42,7 @@ services: ALLOWED_HOSTS: '*' CVAT_REDIS_HOST: 'cvat_redis' CVAT_POSTGRES_HOST: 'cvat_db' + ADAPTIVE_AUTO_ANNOTATION: 'false' volumes: - cvat_data:/home/django/data - cvat_keys:/home/django/keys