diff --git a/CHANGELOG.md b/CHANGELOG.md index bc78e356c487..4b68c0890164 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -89,6 +89,7 @@ non-ascii paths while adding files from "Connected file share" (issue #4428) - Fixed FBRS serverless function runtime error on images with alpha channel () - Attaching manifest with custom name () - Uploading non-zip annotaion files () +- A permission problem with interactive model launches for workers in orgs () - Fix chart not being upgradable () - Broken helm chart - if using custom release name () - Missing source tag in project annotations () diff --git a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx index 6258b0dc1162..4fd5fec42ec3 100644 --- a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx +++ b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx @@ -369,7 +369,8 @@ export class ToolsControlComponent extends React.PureComponent { try { // run server request this.setState({ fetching: true }); - const response = await core.lambda.call(jobInstance.taskId, interactor, data); + const response = await core.lambda.call(jobInstance.taskId, interactor, + { ...data, job: jobInstance.id }); // approximation with cv.approxPolyDP const approximated = await this.approximateResponsePoints(response.points); @@ -740,6 +741,7 @@ export class ToolsControlComponent extends React.PureComponent { const response = await core.lambda.call(jobInstance.taskId, tracker, { frame: frame - 1, shapes: trackableObjects.shapes, + job: jobInstance.id, }); const { states: serverlessStates } = response; @@ -787,6 +789,7 @@ export class ToolsControlComponent extends React.PureComponent { frame, shapes: trackableObjects.shapes, states: trackableObjects.states, + job: jobInstance.id, }); response.shapes = response.shapes.map(trackedRectangleMapper); @@ -1161,7 +1164,9 @@ export class ToolsControlComponent extends React.PureComponent { runInference={async (model: Model, body: DetectorRequestBody) => { try { this.setState({ mode: 'detection', fetching: true }); - const result = await core.lambda.call(jobInstance.taskId, model, { ...body, frame }); + const result = await core.lambda.call(jobInstance.taskId, model, { + ...body, frame, job: jobInstance.id, + }); const states = result.map( (data: any): any => { const jobLabel = (jobInstance.labels as Label[]) diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index 44ca08ee1c4c..53b156f747af 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -446,6 +446,9 @@ class Segment(models.Model): start_frame = models.IntegerField() stop_frame = models.IntegerField() + def contains_frame(self, idx: int) -> bool: + return self.start_frame <= idx and idx <= self.stop_frame + class Meta: default_permissions = () @@ -472,6 +475,11 @@ def get_project_id(self): project = self.segment.task.project return project.id if project else None + @extend_schema_field(OpenApiTypes.INT) + def get_task_id(self): + task = self.segment.task + return task.id if task else None + def get_organization_id(self): return self.segment.task.organization diff --git a/cvat/apps/iam/permissions.py b/cvat/apps/iam/permissions.py index 5952ee7e5a4a..079e72e8afc2 100644 --- a/cvat/apps/iam/permissions.py +++ b/cvat/apps/iam/permissions.py @@ -365,12 +365,15 @@ class LambdaPermission(OpenPolicyAgentPermission): def create(cls, request, view, obj): permissions = [] if view.basename == 'function' or view.basename == 'request': - for scope in cls.get_scopes(request, view, obj): + scopes = cls.get_scopes(request, view, obj) + for scope in scopes: self = cls.create_base_perm(request, view, scope, obj) permissions.append(self) - task_id = request.data.get('task') - if task_id: + if job_id := request.data.get('job'): + perm = JobPermission.create_scope_view_data(request, job_id) + permissions.append(perm) + elif task_id := request.data.get('task'): perm = TaskPermission.create_scope_view_data(request, task_id) permissions.append(perm) @@ -879,6 +882,14 @@ def create(cls, request, view, obj): return permissions + @classmethod + def create_scope_view_data(cls, request, job_id): + try: + obj = Job.objects.get(id=job_id) + except Job.DoesNotExist as ex: + raise ValidationError(str(ex)) + return cls(**cls.unpack_context(request), obj=obj, scope='view:data') + def __init__(self, **kwargs): super().__init__(**kwargs) self.url = settings.IAM_OPA_DATA_URL + '/jobs/allow' diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py index 4abfcf3b5869..b72f3f9f04c3 100644 --- a/cvat/apps/lambda_manager/tests/test_lambda.py +++ b/cvat/apps/lambda_manager/tests/test_lambda.py @@ -6,6 +6,7 @@ import json from collections import OrderedDict from io import BytesIO +from typing import Dict, Optional from unittest import mock, skip import os @@ -71,7 +72,7 @@ def __exit__(self, exception_type, exception_value, traceback): if self.user: self.client.logout() -class LambdaTestCase(APITestCase): +class _LambdaTestCaseBase(APITestCase): def setUp(self): self.client = APIClient() @@ -83,11 +84,6 @@ def setUp(self): self.addCleanup(invoke_patcher.stop) invoke_patcher.start() - images_main_task = self._generate_task_images(3) - images_assigneed_to_user_task = self._generate_task_images(3) - self.main_task = self._create_task(tasks["main"], images_main_task) - self.assigneed_to_user_task = self._create_task(tasks["assigneed_to_user"], images_assigneed_to_user_task) - def __get_data_from_lambda_manager_http(self, **kwargs): url = kwargs["url"] if url == "/api/functions": @@ -143,24 +139,28 @@ def _create_db_users(cls): user_admin = User.objects.create_superuser(username="admin", email="", password="admin") user_admin.groups.add(group_admin) - user_dummy = User.objects.create_user(username="user", password="user") + user_dummy = User.objects.create_user(username="user", password="user", + email="user@example.com") user_dummy.groups.add(group_user) cls.admin = user_admin cls.user = user_dummy - def _create_task(self, data, image_data): - with ForceLogin(self.admin, self.client): - response = self.client.post('/api/tasks', data=data, format="json") + def _create_task(self, data, image_data, *, owner=None, org_id=None): + with ForceLogin(owner or self.admin, self.client): + response = self.client.post('/api/tasks', data=data, format="json", + QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) assert response.status_code == status.HTTP_201_CREATED, response.status_code tid = response.data["id"] response = self.client.post("/api/tasks/%s/data" % tid, - data=image_data) + data=image_data, + QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code - response = self.client.get("/api/tasks/%s" % tid) + response = self.client.get("/api/tasks/%s" % tid, + QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) task = response.data return task @@ -180,26 +180,37 @@ def setUpTestData(cls): cls._create_db_users() - def _get_request(self, path, user): + def _get_request(self, path, user, *, org_id=None): with ForceLogin(user, self.client): - response = self.client.get(path) + response = self.client.get(path, + QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') return response - def _delete_request(self, path, user): + def _delete_request(self, path, user, *, org_id=None): with ForceLogin(user, self.client): - response = self.client.delete(path) + response = self.client.delete(path, + QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') return response - def _post_request(self, path, user, data): + def _post_request(self, path, user, data, *, org_id=None): data = json.dumps(data) with ForceLogin(user, self.client): - response = self.client.post(path, data=data, content_type='application/json') + response = self.client.post(path, data=data, content_type='application/json', + QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') return response - def __check_expected_keys_in_response_function(self, data): + def _patch_request(self, path, user, data, *, org_id=None): + data = json.dumps(data) + with ForceLogin(user, self.client): + response = self.client.patch(path, data=data, content_type='application/json', + QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') + return response + + + def _check_expected_keys_in_response_function(self, data): kind = data["kind"] if kind == "interactor": for key in expected_keys_in_response_function_interactor: @@ -212,16 +223,27 @@ def __check_expected_keys_in_response_function(self, data): self.assertIn(key, data) +class LambdaTestCases(_LambdaTestCaseBase): + def setUp(self): + super().setUp() + + images_main_task = self._generate_task_images(3) + images_assigneed_to_user_task = self._generate_task_images(3) + self.main_task = self._create_task(tasks["main"], images_main_task) + self.assigneed_to_user_task = self._create_task( + tasks["assigneed_to_user"], images_assigneed_to_user_task + ) + def test_api_v2_lambda_functions_list(self): response = self._get_request(LAMBDA_FUNCTIONS_PATH, self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) for data in response.data: - self.__check_expected_keys_in_response_function(data) + self._check_expected_keys_in_response_function(data) response = self._get_request(LAMBDA_FUNCTIONS_PATH, self.user) self.assertEqual(response.status_code, status.HTTP_200_OK) for data in response.data: - self.__check_expected_keys_in_response_function(data) + self._check_expected_keys_in_response_function(data) response = self._get_request(LAMBDA_FUNCTIONS_PATH, None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) @@ -257,11 +279,11 @@ def test_api_v2_lambda_functions_read(self): response = self._get_request(path, self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.__check_expected_keys_in_response_function(response.data) + self._check_expected_keys_in_response_function(response.data) response = self._get_request(path, self.user) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.__check_expected_keys_in_response_function(response.data) + self._check_expected_keys_in_response_function(response.data) response = self._get_request(path, None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) @@ -966,3 +988,151 @@ def test_api_v2_lambda_functions_create_function_is_not_ready(self): response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_error}", self.admin, data) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + + +class Issue4996_Cases(_LambdaTestCaseBase): + # Check regressions for https://github.com/opencv/cvat/issues/4996#issuecomment-1266123032 + # We need to check that job assignee can call functions in the assigned jobs + # This requires to pass the job id in the call request. + + def _create_org(self, *, owner: int, members: Dict[int, str] = None) -> dict: + org = self._post_request('/api/organizations', user=owner, data={ + "slug": "testorg", + "name": "test Org", + }) + assert org.status_code == status.HTTP_201_CREATED + org = org.json() + + for uid, role in members.items(): + user = self._get_request('/api/users/self', user=uid) + assert user.status_code == status.HTTP_200_OK + user = user.json() + + invitation = self._post_request('/api/invitations', user=owner, data={ + 'email': user['email'], + 'role': role, + }, org_id=org['id']) + assert invitation.status_code == status.HTTP_201_CREATED + + return org + + def _set_task_assignee(self, task: int, assignee: Optional[int], *, + org_id: Optional[int] = None): + response = self._patch_request(f'/api/tasks/{task}', user=self.admin, data={ + 'assignee_id': assignee, + }, org_id=org_id) + assert response.status_code == status.HTTP_200_OK + + def _set_job_assignee(self, job: int, assignee: Optional[int], *, + org_id: Optional[int] = None): + response = self._patch_request(f'/api/jobs/{job}', user=self.admin, data={ + 'assignee': assignee, + }, org_id=org_id) + assert response.status_code == status.HTTP_200_OK + + def setUp(self): + self.org = self._create_org(owner=self.admin, members={self.user: 'worker'}) + + task = self._create_task(data={ + 'name': 'test_task', + 'labels': [{'name': 'cat'}], + 'segment_size': 2 + }, + image_data=self._generate_task_images(6), + owner=self.admin, + org_id=self.org['id'], + ) + self.task = task + + jobs = self._get_request(f"/api/tasks/{self.task['id']}/jobs", self.admin, + org_id=self.org['id']) + assert jobs.status_code == status.HTTP_200_OK + self.job = jobs.json()[1] + + self.common_data = { + "task": self.task['id'], + "frame": 0, + "cleanup": True, + "mapping": { + "car": { "name": "car" }, + }, + } + + self.function_name = f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}" + + return super().setUp() + + def _get_valid_job_params(self): + return { + "job": self.job['id'], + "frame": 2 + } + + def _get_invalid_job_params(self): + return { + "job": self.job['id'], + "frame": 0 + } + + def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_task_request(self): + data = self.common_data.copy() + with self.subTest(job=None, assignee=None): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_job_request(self): + data = self.common_data.copy() + data.update(self._get_valid_job_params()) + with self.subTest(job='defined', assignee=None): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_can_call_function_for_job_worker_in_org__allow_task_assigned_worker_with_task_request(self): + self._set_task_assignee(self.task['id'], self.user.id, org_id=self.org['id']) + + data = self.common_data.copy() + with self.subTest(job=None, assignee='task'): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_can_call_function_for_job_worker_in_org__deny_job_assigned_worker_with_task_request(self): + self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + + data = self.common_data.copy() + with self.subTest(job=None, assignee='job'): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_can_call_function_for_job_worker_in_org__allow_job_assigned_worker_with_job_request(self): + self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + + data = self.common_data.copy() + data.update(self._get_valid_job_params()) + with self.subTest(job='defined', assignee='job'): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_can_check_job_boundaries_in_function_call__fail_for_frame_outside_job(self): + self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + + data = self.common_data.copy() + data.update(self._get_invalid_job_params()) + with self.subTest(job='defined', frame='outside'): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_can_check_job_boundaries_in_function_call__ok_for_frame_inside_job(self): + self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + + data = self.common_data.copy() + data.update(self._get_valid_job_params()) + with self.subTest(job='defined', frame='inside'): + response = self._post_request(self.function_name, self.user, data, + org_id=self.org['id']) + self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 9a9209311f2e..3427562a9633 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -8,6 +8,8 @@ from functools import wraps from enum import Enum from copy import deepcopy +import textwrap +from typing import Any, Dict, Optional import django_rq import requests @@ -16,16 +18,17 @@ from django.conf import settings from django.core.exceptions import ObjectDoesNotExist, ValidationError -from rest_framework import status, viewsets +from rest_framework import status, viewsets, serializers from rest_framework.response import Response import cvat.apps.dataset_manager as dm from cvat.apps.engine.frame_provider import FrameProvider -from cvat.apps.engine.models import Task as TaskModel +from cvat.apps.engine.models import Job, Task from cvat.apps.engine.serializers import LabeledDataSerializer from cvat.apps.engine.models import ShapeType, SourceType -from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiResponse, OpenApiParameter +from drf_spectacular.utils import (extend_schema, extend_schema_view, + OpenApiResponse, OpenApiParameter, inline_serializer) from drf_spectacular.types import OpenApiTypes class LambdaType(Enum): @@ -175,8 +178,13 @@ def to_dict(self): return response - def invoke(self, db_task, data): + def invoke(self, db_task: Task, data: Dict[str, Any], *, db_job: Optional[Job] = None): try: + if db_job is not None and db_job.get_task_id() != db_task.id: + raise ValidationError("Job task id does not match task id", + code=status.HTTP_400_BAD_REQUEST + ) + payload = {} data = {k: v for k,v in data.items() if v is not None} threshold = data.get("threshold") @@ -225,6 +233,16 @@ def invoke(self, db_task, data): if mapped_attr in task_attr_names: supported_attrs[func_label].update({ attr["name"]: task_attributes[mapped_label][mapped_attr] }) + # Check job frame boundaries + for key, desc in ( + ('frame', 'frame'), + ('frame0', 'start frame'), + ('frame1', 'end frame'), + ): + if key in data and db_job and not db_job.segment.contains_frame(data[key]): + raise ValidationError(f"The {desc} is outside the job range", + code=status.HTTP_400_BAD_REQUEST) + if self.kind == LambdaType.DETECTOR: payload.update({ "image": self._get_image(db_task, data["frame"], quality) @@ -647,7 +665,7 @@ def _call_reid(function, db_task, quality, threshold, max_distance): @staticmethod def __call__(function, task, quality, cleanup, **kwargs): # TODO: need logging - db_task = TaskModel.objects.get(pk=task) + db_task = Task.objects.get(pk=task) if cleanup: dm.task.delete_task_data(db_task.id) db_labels = (db_task.project.label_set if db_task.project_id else db_task.label_set).prefetch_related("attributespec_set").all() @@ -685,7 +703,7 @@ def func_wrapper(*args, **kwargs): status_code = status.HTTP_500_INTERNAL_SERVER_ERROR data = str(err) except ValidationError as err: - status_code = err.code + status_code = err.code or status.HTTP_400_BAD_REQUEST data = err.message except ObjectDoesNotExist as err: status_code = status.HTTP_400_BAD_REQUEST @@ -725,12 +743,34 @@ def retrieve(self, request, func_id): gateway = LambdaGateway() return gateway.get(func_id).to_dict() + @extend_schema(description=textwrap.dedent("""\ + Allows to execute a function for immediate computation. + + Intended for short-lived executions, useful for interactive calls. + + When executed for interactive annotation, the job id must be specified + in the 'job' input field. The task id is not required in this case, + but if it is specified, it must match the job task id. + """), + request=inline_serializer("OnlineFunctionCall", fields={ + "job": serializers.IntegerField(required=False), + "task": serializers.IntegerField(required=False), + }), + responses=OpenApiResponse(description="Returns function invocation results") + ) @return_response() def call(self, request, func_id): self.check_object_permissions(request, func_id) try: - task_id = request.data['task'] - db_task = TaskModel.objects.get(pk=task_id) + job_id = request.data.get('job') + job = None + if job_id is not None: + job = Job.objects.get(id=job_id) + task_id = job.get_task_id() + else: + task_id = request.data['task'] + + db_task = Task.objects.get(pk=task_id) except (KeyError, ObjectDoesNotExist) as err: raise ValidationError( '`{}` lambda function was run '.format(func_id) + @@ -740,7 +780,7 @@ def call(self, request, func_id): gateway = LambdaGateway() lambda_func = gateway.get(func_id) - return lambda_func.invoke(db_task, request.data) + return lambda_func.invoke(db_task, request.data, db_job=job) @extend_schema(tags=['lambda']) @extend_schema_view(