From 0d9314cd4a1a605f418e9dc9721375d6b2dfb8ca Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Wed, 30 Nov 2022 21:08:46 +0600 Subject: [PATCH 01/14] Allow to pass job_id for online model invocations --- cvat/apps/engine/models.py | 8 ++++++++ cvat/apps/iam/permissions.py | 22 ++++++++++++++++---- cvat/apps/lambda_manager/views.py | 34 +++++++++++++++++++++++++------ 3 files changed, 54 insertions(+), 10 deletions(-) diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index 44ca08ee1c4c..bdca7995167a 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -446,6 +446,9 @@ class Segment(models.Model): start_frame = models.IntegerField() stop_frame = models.IntegerField() + def contains_frame(self, idx: int) -> bool: + return self.start_frame <= idx and idx < self.stop_frame + class Meta: default_permissions = () @@ -472,6 +475,11 @@ def get_project_id(self): project = self.segment.task.project return project.id if project else None + @extend_schema_field(OpenApiTypes.INT) + def get_task_id(self): + task = self.segment.task + return task.id if task else None + def get_organization_id(self): return self.segment.task.organization diff --git a/cvat/apps/iam/permissions.py b/cvat/apps/iam/permissions.py index 5952ee7e5a4a..b28d63a0777c 100644 --- a/cvat/apps/iam/permissions.py +++ b/cvat/apps/iam/permissions.py @@ -365,14 +365,20 @@ class LambdaPermission(OpenPolicyAgentPermission): def create(cls, request, view, obj): permissions = [] if view.basename == 'function' or view.basename == 'request': - for scope in cls.get_scopes(request, view, obj): + scopes = cls.get_scopes(request, view, obj) + for scope in scopes: self = cls.create_base_perm(request, view, scope, obj) permissions.append(self) - task_id = request.data.get('task') - if task_id: - perm = TaskPermission.create_scope_view_data(request, task_id) + job_id = request.data.get('job') + if job_id: + perm = JobPermission.create_scope_view_data(request, job_id) permissions.append(perm) + else: + task_id = request.data.get('task') + if task_id: + perm = TaskPermission.create_scope_view_data(request, task_id) + permissions.append(perm) return permissions @@ -879,6 +885,14 @@ def create(cls, request, view, obj): return permissions + @classmethod + def create_scope_view_data(cls, request, job_id): + try: + obj = Job.objects.get(id=job_id) + except Job.DoesNotExist as ex: + raise ValidationError(str(ex)) + return cls(**cls.unpack_context(request), obj=obj, scope='view:data') + def __init__(self, **kwargs): super().__init__(**kwargs) self.url = settings.IAM_OPA_DATA_URL + '/jobs/allow' diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 9a9209311f2e..938217927b21 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -8,6 +8,7 @@ from functools import wraps from enum import Enum from copy import deepcopy +from typing import Any, Dict, Optional import django_rq import requests @@ -21,7 +22,7 @@ import cvat.apps.dataset_manager as dm from cvat.apps.engine.frame_provider import FrameProvider -from cvat.apps.engine.models import Task as TaskModel +from cvat.apps.engine.models import Job, Task from cvat.apps.engine.serializers import LabeledDataSerializer from cvat.apps.engine.models import ShapeType, SourceType @@ -175,7 +176,7 @@ def to_dict(self): return response - def invoke(self, db_task, data): + def invoke(self, db_task: Task, data: Dict[str, Any], *, db_job: Optional[Job] = None): try: payload = {} data = {k: v for k,v in data.items() if v is not None} @@ -226,10 +227,16 @@ def invoke(self, db_task, data): supported_attrs[func_label].update({ attr["name"]: task_attributes[mapped_label][mapped_attr] }) if self.kind == LambdaType.DETECTOR: + if db_job and not db_job.segment.contains_frame(data["frame"]): + raise ValidationError("the frame is outside the job range") + payload.update({ "image": self._get_image(db_task, data["frame"], quality) }) elif self.kind == LambdaType.INTERACTOR: + if db_job and not db_job.segment.contains_frame(data["frame"]): + raise ValidationError("the frame is outside the job range") + payload.update({ "image": self._get_image(db_task, data["frame"], quality), "pos_points": data["pos_points"][2:] if self.startswith_box else data["pos_points"], @@ -237,6 +244,11 @@ def invoke(self, db_task, data): "obj_bbox": data["pos_points"][0:2] if self.startswith_box else None }) elif self.kind == LambdaType.REID: + if db_job and not db_job.segment.contains_frame(data["frame0"]): + raise ValidationError("the start frame is outside the job range") + if db_job and not db_job.segment.contains_frame(data["frame1"]): + raise ValidationError("the end frame is outside the job range") + payload.update({ "image0": self._get_image(db_task, data["frame0"], quality), "image1": self._get_image(db_task, data["frame1"], quality), @@ -249,6 +261,9 @@ def invoke(self, db_task, data): "max_distance": max_distance }) elif self.kind == LambdaType.TRACKER: + if db_job and not db_job.segment.contains_frame(data["frame"]): + raise ValidationError("the frame is outside the job range") + payload.update({ "image": self._get_image(db_task, data["frame"], quality), "shapes": data.get("shapes", []), @@ -647,7 +662,7 @@ def _call_reid(function, db_task, quality, threshold, max_distance): @staticmethod def __call__(function, task, quality, cleanup, **kwargs): # TODO: need logging - db_task = TaskModel.objects.get(pk=task) + db_task = Task.objects.get(pk=task) if cleanup: dm.task.delete_task_data(db_task.id) db_labels = (db_task.project.label_set if db_task.project_id else db_task.label_set).prefetch_related("attributespec_set").all() @@ -729,8 +744,15 @@ def retrieve(self, request, func_id): def call(self, request, func_id): self.check_object_permissions(request, func_id) try: - task_id = request.data['task'] - db_task = TaskModel.objects.get(pk=task_id) + job_id = request.data.get('job') + job = None + if job_id is not None: + job = Job.objects.get(id=job_id) + task_id = job.get_task_id() + else: + task_id = request.data['task'] + + db_task = Task.objects.get(pk=task_id) except (KeyError, ObjectDoesNotExist) as err: raise ValidationError( '`{}` lambda function was run '.format(func_id) + @@ -740,7 +762,7 @@ def call(self, request, func_id): gateway = LambdaGateway() lambda_func = gateway.get(func_id) - return lambda_func.invoke(db_task, request.data) + return lambda_func.invoke(db_task, request.data, db_job=job) @extend_schema(tags=['lambda']) @extend_schema_view( From 9c2ff21938803d7f2935925aad4af6e6927349c0 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Wed, 30 Nov 2022 21:51:01 +0600 Subject: [PATCH 02/14] Update the UI --- .../standard-workspace/controls-side-bar/tools-control.tsx | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx index 6258b0dc1162..ccca7f810811 100644 --- a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx +++ b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx @@ -740,6 +740,7 @@ export class ToolsControlComponent extends React.PureComponent { const response = await core.lambda.call(jobInstance.taskId, tracker, { frame: frame - 1, shapes: trackableObjects.shapes, + job: jobInstance.id, }); const { states: serverlessStates } = response; @@ -787,6 +788,7 @@ export class ToolsControlComponent extends React.PureComponent { frame, shapes: trackableObjects.shapes, states: trackableObjects.states, + job: jobInstance.id, }); response.shapes = response.shapes.map(trackedRectangleMapper); @@ -1161,7 +1163,9 @@ export class ToolsControlComponent extends React.PureComponent { runInference={async (model: Model, body: DetectorRequestBody) => { try { this.setState({ mode: 'detection', fetching: true }); - const result = await core.lambda.call(jobInstance.taskId, model, { ...body, frame }); + const result = await core.lambda.call(jobInstance.taskId, model, { + ...body, frame, job: jobInstance.id, + }); const states = result.map( (data: any): any => { const jobLabel = (jobInstance.labels as Label[]) From e78949c736b16b5418b50d3ec35af74b3f503796 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 1 Dec 2022 00:13:38 +0600 Subject: [PATCH 03/14] Send job id in UI --- .../standard-workspace/controls-side-bar/tools-control.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx index ccca7f810811..4fd5fec42ec3 100644 --- a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx +++ b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx @@ -369,7 +369,8 @@ export class ToolsControlComponent extends React.PureComponent { try { // run server request this.setState({ fetching: true }); - const response = await core.lambda.call(jobInstance.taskId, interactor, data); + const response = await core.lambda.call(jobInstance.taskId, interactor, + { ...data, job: jobInstance.id }); // approximation with cv.approxPolyDP const approximated = await this.approximateResponsePoints(response.points); From 0e4719dc76e397dab2bf8342e3d60742da541585 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 1 Dec 2022 00:14:46 +0600 Subject: [PATCH 04/14] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 54de3684ba6d..f0ab1d9c73a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,6 +82,7 @@ non-ascii paths while adding files from "Connected file share" (issue #4428) - Visibility and ignored information fail to be loaded (MOT dataset format) () - Added force logout on CVAT app start if token is missing () - Missed token with using social account authentication () +- A permission denied problem with interactive model launches for workers in orgs () ### Security - TDB From 84237d6b1ce8ba4fd0c6405b59979c10cddeba87 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 1 Dec 2022 16:16:25 +0600 Subject: [PATCH 05/14] Fix job boundary check --- cvat/apps/engine/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index bdca7995167a..53b156f747af 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -447,7 +447,7 @@ class Segment(models.Model): stop_frame = models.IntegerField() def contains_frame(self, idx: int) -> bool: - return self.start_frame <= idx and idx < self.stop_frame + return self.start_frame <= idx and idx <= self.stop_frame class Meta: default_permissions = () From 6a076bb24c9b3280d2797a05d7c110a0a1ea14cb Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 1 Dec 2022 16:20:54 +0600 Subject: [PATCH 06/14] Add function doc --- cvat/apps/lambda_manager/views.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 938217927b21..b6eb59f3481c 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -740,6 +740,19 @@ def retrieve(self, request, func_id): gateway = LambdaGateway() return gateway.get(func_id).to_dict() + @extend_schema(description=textwrap.dedent("""\ + Allows to execute a function for immediate computation. + + Supposed for short-living executions, useful for interactive calls. + When executed for interactive annotation, the job id must be specified + in the 'job' input field. + """), + request=inline_serializer("OnlineFunctionCall", fields={ + "job": serializers.IntegerField(required=False), + "task": serializers.IntegerField(required=False), + }), + responses=OpenApiResponse(description="Returns function invocation results") + ) @return_response() def call(self, request, func_id): self.check_object_permissions(request, func_id) From 83710dadc93fdf93f23af4a6ff0ba73096da6d29 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 1 Dec 2022 16:21:12 +0600 Subject: [PATCH 07/14] Fix error message --- cvat/apps/lambda_manager/views.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index b6eb59f3481c..603debe5ff76 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -8,6 +8,7 @@ from functools import wraps from enum import Enum from copy import deepcopy +import textwrap from typing import Any, Dict, Optional import django_rq @@ -17,7 +18,7 @@ from django.conf import settings from django.core.exceptions import ObjectDoesNotExist, ValidationError -from rest_framework import status, viewsets +from rest_framework import status, viewsets, serializers from rest_framework.response import Response import cvat.apps.dataset_manager as dm @@ -26,7 +27,8 @@ from cvat.apps.engine.serializers import LabeledDataSerializer from cvat.apps.engine.models import ShapeType, SourceType -from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiResponse, OpenApiParameter +from drf_spectacular.utils import (extend_schema, extend_schema_view, + OpenApiResponse, OpenApiParameter, inline_serializer) from drf_spectacular.types import OpenApiTypes class LambdaType(Enum): @@ -228,14 +230,16 @@ def invoke(self, db_task: Task, data: Dict[str, Any], *, db_job: Optional[Job] = if self.kind == LambdaType.DETECTOR: if db_job and not db_job.segment.contains_frame(data["frame"]): - raise ValidationError("the frame is outside the job range") + raise ValidationError("the frame is outside the job range", + code=status.HTTP_400_BAD_REQUEST) payload.update({ "image": self._get_image(db_task, data["frame"], quality) }) elif self.kind == LambdaType.INTERACTOR: if db_job and not db_job.segment.contains_frame(data["frame"]): - raise ValidationError("the frame is outside the job range") + raise ValidationError("the frame is outside the job range", + code=status.HTTP_400_BAD_REQUEST) payload.update({ "image": self._get_image(db_task, data["frame"], quality), @@ -245,9 +249,11 @@ def invoke(self, db_task: Task, data: Dict[str, Any], *, db_job: Optional[Job] = }) elif self.kind == LambdaType.REID: if db_job and not db_job.segment.contains_frame(data["frame0"]): - raise ValidationError("the start frame is outside the job range") + raise ValidationError("the start frame is outside the job range", + code=status.HTTP_400_BAD_REQUEST) if db_job and not db_job.segment.contains_frame(data["frame1"]): - raise ValidationError("the end frame is outside the job range") + raise ValidationError("the end frame is outside the job range", + code=status.HTTP_400_BAD_REQUEST) payload.update({ "image0": self._get_image(db_task, data["frame0"], quality), @@ -262,7 +268,8 @@ def invoke(self, db_task: Task, data: Dict[str, Any], *, db_job: Optional[Job] = }) elif self.kind == LambdaType.TRACKER: if db_job and not db_job.segment.contains_frame(data["frame"]): - raise ValidationError("the frame is outside the job range") + raise ValidationError("the frame is outside the job range", + code=status.HTTP_400_BAD_REQUEST) payload.update({ "image": self._get_image(db_task, data["frame"], quality), @@ -700,7 +707,7 @@ def func_wrapper(*args, **kwargs): status_code = status.HTTP_500_INTERNAL_SERVER_ERROR data = str(err) except ValidationError as err: - status_code = err.code + status_code = err.code or status.HTTP_400_BAD_REQUEST data = err.message except ObjectDoesNotExist as err: status_code = status.HTTP_400_BAD_REQUEST From 97b3f2b3b3646ff822c6c208a1a549c6d06295ca Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 1 Dec 2022 16:21:44 +0600 Subject: [PATCH 08/14] Add tests --- cvat/apps/lambda_manager/tests/test_lambda.py | 227 ++++++++++++++++-- 1 file changed, 204 insertions(+), 23 deletions(-) diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py index 4abfcf3b5869..1e6698e7f9ff 100644 --- a/cvat/apps/lambda_manager/tests/test_lambda.py +++ b/cvat/apps/lambda_manager/tests/test_lambda.py @@ -6,6 +6,7 @@ import json from collections import OrderedDict from io import BytesIO +from typing import Optional from unittest import mock, skip import os @@ -71,7 +72,7 @@ def __exit__(self, exception_type, exception_value, traceback): if self.user: self.client.logout() -class LambdaTestCase(APITestCase): +class _LambdaTestCaseBase(APITestCase): def setUp(self): self.client = APIClient() @@ -83,11 +84,6 @@ def setUp(self): self.addCleanup(invoke_patcher.stop) invoke_patcher.start() - images_main_task = self._generate_task_images(3) - images_assigneed_to_user_task = self._generate_task_images(3) - self.main_task = self._create_task(tasks["main"], images_main_task) - self.assigneed_to_user_task = self._create_task(tasks["assigneed_to_user"], images_assigneed_to_user_task) - def __get_data_from_lambda_manager_http(self, **kwargs): url = kwargs["url"] if url == "/api/functions": @@ -143,24 +139,28 @@ def _create_db_users(cls): user_admin = User.objects.create_superuser(username="admin", email="", password="admin") user_admin.groups.add(group_admin) - user_dummy = User.objects.create_user(username="user", password="user") + user_dummy = User.objects.create_user(username="user", password="user", + email="user@example.com") user_dummy.groups.add(group_user) cls.admin = user_admin cls.user = user_dummy - def _create_task(self, data, image_data): - with ForceLogin(self.admin, self.client): - response = self.client.post('/api/tasks', data=data, format="json") + def _create_task(self, data, image_data, *, owner=None, org_id=None): + with ForceLogin(owner or self.admin, self.client): + response = self.client.post('/api/tasks', data=data, format="json", + QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) assert response.status_code == status.HTTP_201_CREATED, response.status_code tid = response.data["id"] response = self.client.post("/api/tasks/%s/data" % tid, - data=image_data) + data=image_data, + QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code - response = self.client.get("/api/tasks/%s" % tid) + response = self.client.get("/api/tasks/%s" % tid, + QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) task = response.data return task @@ -180,26 +180,37 @@ def setUpTestData(cls): cls._create_db_users() - def _get_request(self, path, user): + def _get_request(self, path, user, *, org_id=None): with ForceLogin(user, self.client): - response = self.client.get(path) + response = self.client.get(path, + QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') return response - def _delete_request(self, path, user): + def _delete_request(self, path, user, *, org_id=None): with ForceLogin(user, self.client): - response = self.client.delete(path) + response = self.client.delete(path, + QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') return response - def _post_request(self, path, user, data): + def _post_request(self, path, user, data, *, org_id=None): data = json.dumps(data) with ForceLogin(user, self.client): - response = self.client.post(path, data=data, content_type='application/json') + response = self.client.post(path, data=data, content_type='application/json', + QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') return response - def __check_expected_keys_in_response_function(self, data): + def _patch_request(self, path, user, data, *, org_id=None): + data = json.dumps(data) + with ForceLogin(user, self.client): + response = self.client.patch(path, data=data, content_type='application/json', + QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') + return response + + + def _check_expected_keys_in_response_function(self, data): kind = data["kind"] if kind == "interactor": for key in expected_keys_in_response_function_interactor: @@ -212,16 +223,27 @@ def __check_expected_keys_in_response_function(self, data): self.assertIn(key, data) +class LambdaTestCases(_LambdaTestCaseBase): + def setUp(self): + super().setUp() + + images_main_task = self._generate_task_images(3) + images_assigneed_to_user_task = self._generate_task_images(3) + self.main_task = self._create_task(tasks["main"], images_main_task) + self.assigneed_to_user_task = self._create_task( + tasks["assigneed_to_user"], images_assigneed_to_user_task + ) + def test_api_v2_lambda_functions_list(self): response = self._get_request(LAMBDA_FUNCTIONS_PATH, self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) for data in response.data: - self.__check_expected_keys_in_response_function(data) + self._check_expected_keys_in_response_function(data) response = self._get_request(LAMBDA_FUNCTIONS_PATH, self.user) self.assertEqual(response.status_code, status.HTTP_200_OK) for data in response.data: - self.__check_expected_keys_in_response_function(data) + self._check_expected_keys_in_response_function(data) response = self._get_request(LAMBDA_FUNCTIONS_PATH, None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) @@ -257,11 +279,11 @@ def test_api_v2_lambda_functions_read(self): response = self._get_request(path, self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.__check_expected_keys_in_response_function(response.data) + self._check_expected_keys_in_response_function(response.data) response = self._get_request(path, self.user) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.__check_expected_keys_in_response_function(response.data) + self._check_expected_keys_in_response_function(response.data) response = self._get_request(path, None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) @@ -966,3 +988,162 @@ def test_api_v2_lambda_functions_create_function_is_not_ready(self): response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_error}", self.admin, data) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + + +class Issue4996_Cases(_LambdaTestCaseBase): + # Check regressions for https://github.com/opencv/cvat/issues/4996#issuecomment-1266123032 + # We need to check that job assignee can call functions in the assigned jobs + # This requires to pass the job id in the call request. + + def _create_org(self, *, owner: int, members: dict[int, str] = None) -> dict: + org = self._post_request('/api/organizations', user=owner, data={ + "slug": "testorg", + "name": "test Org", + }) + assert org.status_code == status.HTTP_201_CREATED + org = org.json() + + for uid, role in members.items(): + user = self._get_request('/api/users/self', user=uid) + assert user.status_code == status.HTTP_200_OK + user = user.json() + + invitation = self._post_request('/api/invitations', user=owner, data={ + 'email': user['email'], + 'role': role, + }, org_id=org['id']) + assert invitation.status_code == status.HTTP_201_CREATED + + return org + + def _set_task_assignee(self, task: int, assignee: Optional[int], *, + org_id: Optional[int] = None): + response = self._patch_request(f'/api/tasks/{task}', user=self.admin, data={ + 'assignee_id': assignee, + }, org_id=org_id) + assert response.status_code == status.HTTP_200_OK + + def _set_job_assignee(self, job: int, assignee: Optional[int], *, + org_id: Optional[int] = None): + response = self._patch_request(f'/api/jobs/{job}', user=self.admin, data={ + 'assignee': assignee, + }, org_id=org_id) + assert response.status_code == status.HTTP_200_OK + + def setUp(self): + self.org = self._create_org(owner=self.admin, members={self.user: 'worker'}) + + task = self._create_task(data={ + 'name': 'test_task', + 'labels': [{'name': 'cat'}], + 'segment_size': 2 + }, + image_data=self._generate_task_images(6), + owner=self.admin, + org_id=self.org['id'], + ) + self.task = task + + jobs = self._get_request(f"/api/tasks/{self.task['id']}/jobs", self.admin, + org_id=self.org['id']) + assert jobs.status_code == status.HTTP_200_OK + self.job = jobs.json()[1] + + self.common_data = { + "task": self.task['id'], + "frame": 0, + "cleanup": True, + "mapping": { + "car": { "name": "car" }, + }, + } + + self.function_name = f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}" + + return super().setUp() + + def _get_valid_job_params(self): + return { + "job": self.job['id'], + "frame": 2 + } + + def _get_invalid_job_params(self): + return { + "job": self.job['id'], + "frame": 0 + } + + def test_can_call_function_for_job_worker_in_org(self): + # not assigned + task => denied + # not assigned + job => denied + # task assignee + task => ok + # job assignee + task => denied + # job assignee + job => ok + + data = self.common_data.copy() + with self.subTest(job=None, assignee=None): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + + data = self.common_data.copy() + data.update(self._get_valid_job_params()) + with self.subTest(job='defined', assignee=None): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + + self._set_task_assignee(self.task['id'], self.user.id, org_id=self.org['id']) + + data = self.common_data.copy() + with self.subTest(job=None, assignee='task'): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self._set_task_assignee(self.task['id'], None) + + + self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + + data = self.common_data.copy() + with self.subTest(job=None, assignee='job'): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + self._set_job_assignee(self.job['id'], None) + + + self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + + data = self.common_data.copy() + data.update(self._get_valid_job_params()) + with self.subTest(job='defined', assignee='job'): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_can_check_job_boundaries_in_function_call(self): + # job + frame (outside) => error (outside) + # job + frame (inside) => ok + + self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + + + data = self.common_data.copy() + data.update(self._get_invalid_job_params()) + with self.subTest(job='defined', frame='outside'): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + data = self.common_data.copy() + data.update(self._get_valid_job_params()) + with self.subTest(job='defined', frame='inside'): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_200_OK) From f192bb7193252586cdb5ff04b7cf0f507d9234b3 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Thu, 1 Dec 2022 17:32:16 +0600 Subject: [PATCH 09/14] Add backward compatibility for python --- cvat/apps/lambda_manager/tests/test_lambda.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py index 1e6698e7f9ff..616fa242a2db 100644 --- a/cvat/apps/lambda_manager/tests/test_lambda.py +++ b/cvat/apps/lambda_manager/tests/test_lambda.py @@ -6,7 +6,7 @@ import json from collections import OrderedDict from io import BytesIO -from typing import Optional +from typing import Dict, Optional from unittest import mock, skip import os @@ -995,7 +995,7 @@ class Issue4996_Cases(_LambdaTestCaseBase): # We need to check that job assignee can call functions in the assigned jobs # This requires to pass the job id in the call request. - def _create_org(self, *, owner: int, members: dict[int, str] = None) -> dict: + def _create_org(self, *, owner: int, members: Dict[int, str] = None) -> dict: org = self._post_request('/api/organizations', user=owner, data={ "slug": "testorg", "name": "test Org", From f059164d1c8d5c06601530add7f8fa80e5cb8370 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 5 Dec 2022 13:47:10 +0600 Subject: [PATCH 10/14] Split tests --- cvat/apps/lambda_manager/tests/test_lambda.py | 29 ++++++------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py index 616fa242a2db..b72f3f9f04c3 100644 --- a/cvat/apps/lambda_manager/tests/test_lambda.py +++ b/cvat/apps/lambda_manager/tests/test_lambda.py @@ -1074,20 +1074,14 @@ def _get_invalid_job_params(self): "frame": 0 } - def test_can_call_function_for_job_worker_in_org(self): - # not assigned + task => denied - # not assigned + job => denied - # task assignee + task => ok - # job assignee + task => denied - # job assignee + job => ok - + def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_task_request(self): data = self.common_data.copy() with self.subTest(job=None, assignee=None): response = self._post_request(self.function_name, self.user, data, org_id=self.org['id']) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - + def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_job_request(self): data = self.common_data.copy() data.update(self._get_valid_job_params()) with self.subTest(job='defined', assignee=None): @@ -1095,7 +1089,7 @@ def test_can_call_function_for_job_worker_in_org(self): org_id=self.org['id']) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - + def test_can_call_function_for_job_worker_in_org__allow_task_assigned_worker_with_task_request(self): self._set_task_assignee(self.task['id'], self.user.id, org_id=self.org['id']) data = self.common_data.copy() @@ -1104,9 +1098,7 @@ def test_can_call_function_for_job_worker_in_org(self): org_id=self.org['id']) self.assertEqual(response.status_code, status.HTTP_200_OK) - self._set_task_assignee(self.task['id'], None) - - + def test_can_call_function_for_job_worker_in_org__deny_job_assigned_worker_with_task_request(self): self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) data = self.common_data.copy() @@ -1115,9 +1107,7 @@ def test_can_call_function_for_job_worker_in_org(self): org_id=self.org['id']) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self._set_job_assignee(self.job['id'], None) - - + def test_can_call_function_for_job_worker_in_org__allow_job_assigned_worker_with_job_request(self): self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) data = self.common_data.copy() @@ -1127,13 +1117,9 @@ def test_can_call_function_for_job_worker_in_org(self): org_id=self.org['id']) self.assertEqual(response.status_code, status.HTTP_200_OK) - def test_can_check_job_boundaries_in_function_call(self): - # job + frame (outside) => error (outside) - # job + frame (inside) => ok - + def test_can_check_job_boundaries_in_function_call__fail_for_frame_outside_job(self): self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) - data = self.common_data.copy() data.update(self._get_invalid_job_params()) with self.subTest(job='defined', frame='outside'): @@ -1141,6 +1127,9 @@ def test_can_check_job_boundaries_in_function_call(self): org_id=self.org['id']) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + def test_can_check_job_boundaries_in_function_call__ok_for_frame_inside_job(self): + self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + data = self.common_data.copy() data.update(self._get_valid_job_params()) with self.subTest(job='defined', frame='inside'): From 7a89ec4ecc01da9e533c30be25dd1d79ced09069 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 5 Dec 2022 13:47:23 +0600 Subject: [PATCH 11/14] Refactor some checks --- cvat/apps/iam/permissions.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/cvat/apps/iam/permissions.py b/cvat/apps/iam/permissions.py index b28d63a0777c..079e72e8afc2 100644 --- a/cvat/apps/iam/permissions.py +++ b/cvat/apps/iam/permissions.py @@ -370,15 +370,12 @@ def create(cls, request, view, obj): self = cls.create_base_perm(request, view, scope, obj) permissions.append(self) - job_id = request.data.get('job') - if job_id: + if job_id := request.data.get('job'): perm = JobPermission.create_scope_view_data(request, job_id) permissions.append(perm) - else: - task_id = request.data.get('task') - if task_id: - perm = TaskPermission.create_scope_view_data(request, task_id) - permissions.append(perm) + elif task_id := request.data.get('task'): + perm = TaskPermission.create_scope_view_data(request, task_id) + permissions.append(perm) return permissions From 914d28d5d013cc525658963e194cb0a9566d0c6f Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 5 Dec 2022 14:24:00 +0600 Subject: [PATCH 12/14] Improve wording in docs --- cvat/apps/lambda_manager/views.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 603debe5ff76..967a64b19328 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -750,7 +750,8 @@ def retrieve(self, request, func_id): @extend_schema(description=textwrap.dedent("""\ Allows to execute a function for immediate computation. - Supposed for short-living executions, useful for interactive calls. + Intended for short-lived executions, useful for interactive calls. + When executed for interactive annotation, the job id must be specified in the 'job' input field. """), From 8e34a25122cb5dc15dddba397bd5df0a87a48644 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 5 Dec 2022 14:38:51 +0600 Subject: [PATCH 13/14] Check for task id when job is specified --- cvat/apps/lambda_manager/views.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 967a64b19328..0354fc9699ae 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -180,6 +180,11 @@ def to_dict(self): def invoke(self, db_task: Task, data: Dict[str, Any], *, db_job: Optional[Job] = None): try: + if db_job is not None and db_job.get_task_id() != db_task.id: + raise ValidationError("Job task id does not match task id", + code=status.HTTP_400_BAD_REQUEST + ) + payload = {} data = {k: v for k,v in data.items() if v is not None} threshold = data.get("threshold") @@ -753,7 +758,8 @@ def retrieve(self, request, func_id): Intended for short-lived executions, useful for interactive calls. When executed for interactive annotation, the job id must be specified - in the 'job' input field. + in the 'job' input field. The task id is not required in this case, + but if it is specified, it must match the job task id. """), request=inline_serializer("OnlineFunctionCall", fields={ "job": serializers.IntegerField(required=False), From 4f3dcf6546f19426567a0ab47cb4621cc38da9de Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 5 Dec 2022 15:31:55 +0600 Subject: [PATCH 14/14] Collect frame checks in one place --- cvat/apps/lambda_manager/views.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 0354fc9699ae..3427562a9633 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -233,19 +233,21 @@ def invoke(self, db_task: Task, data: Dict[str, Any], *, db_job: Optional[Job] = if mapped_attr in task_attr_names: supported_attrs[func_label].update({ attr["name"]: task_attributes[mapped_label][mapped_attr] }) - if self.kind == LambdaType.DETECTOR: - if db_job and not db_job.segment.contains_frame(data["frame"]): - raise ValidationError("the frame is outside the job range", + # Check job frame boundaries + for key, desc in ( + ('frame', 'frame'), + ('frame0', 'start frame'), + ('frame1', 'end frame'), + ): + if key in data and db_job and not db_job.segment.contains_frame(data[key]): + raise ValidationError(f"The {desc} is outside the job range", code=status.HTTP_400_BAD_REQUEST) + if self.kind == LambdaType.DETECTOR: payload.update({ "image": self._get_image(db_task, data["frame"], quality) }) elif self.kind == LambdaType.INTERACTOR: - if db_job and not db_job.segment.contains_frame(data["frame"]): - raise ValidationError("the frame is outside the job range", - code=status.HTTP_400_BAD_REQUEST) - payload.update({ "image": self._get_image(db_task, data["frame"], quality), "pos_points": data["pos_points"][2:] if self.startswith_box else data["pos_points"], @@ -253,13 +255,6 @@ def invoke(self, db_task: Task, data: Dict[str, Any], *, db_job: Optional[Job] = "obj_bbox": data["pos_points"][0:2] if self.startswith_box else None }) elif self.kind == LambdaType.REID: - if db_job and not db_job.segment.contains_frame(data["frame0"]): - raise ValidationError("the start frame is outside the job range", - code=status.HTTP_400_BAD_REQUEST) - if db_job and not db_job.segment.contains_frame(data["frame1"]): - raise ValidationError("the end frame is outside the job range", - code=status.HTTP_400_BAD_REQUEST) - payload.update({ "image0": self._get_image(db_task, data["frame0"], quality), "image1": self._get_image(db_task, data["frame1"], quality), @@ -272,10 +267,6 @@ def invoke(self, db_task: Task, data: Dict[str, Any], *, db_job: Optional[Job] = "max_distance": max_distance }) elif self.kind == LambdaType.TRACKER: - if db_job and not db_job.segment.contains_frame(data["frame"]): - raise ValidationError("the frame is outside the job range", - code=status.HTTP_400_BAD_REQUEST) - payload.update({ "image": self._get_image(db_task, data["frame"], quality), "shapes": data.get("shapes", []),