From 69fb309ec3494307854ecd2df91dc65b65f4c516 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Fri, 25 Aug 2023 21:16:35 +0200 Subject: [PATCH] feat: Update Tags CRUD API (#24839) --- .../src/features/tags/TagModal.test.tsx | 46 +++ .../src/features/tags/TagModal.tsx | 321 ++++++++++++++++++ superset-frontend/src/features/tags/tags.ts | 10 + superset-frontend/src/pages/Tags/index.tsx | 77 ++++- superset-frontend/src/views/CRUD/types.ts | 2 + superset/daos/tag.py | 44 +++ superset/tags/api.py | 141 +++++++- superset/tags/commands/create.py | 46 ++- superset/tags/commands/exceptions.py | 8 +- superset/tags/commands/update.py | 78 +++++ superset/tags/schemas.py | 16 +- tests/integration_tests/tags/api_tests.py | 44 +++ tests/unit_tests/dao/tag_test.py | 28 ++ tests/unit_tests/tags/__init__.py | 0 tests/unit_tests/tags/commands/create_test.py | 110 ++++++ tests/unit_tests/tags/commands/update_test.py | 160 +++++++++ 16 files changed, 1109 insertions(+), 22 deletions(-) create mode 100644 superset-frontend/src/features/tags/TagModal.test.tsx create mode 100644 superset-frontend/src/features/tags/TagModal.tsx create mode 100644 superset/tags/commands/update.py create mode 100644 tests/unit_tests/tags/__init__.py create mode 100644 tests/unit_tests/tags/commands/create_test.py create mode 100644 tests/unit_tests/tags/commands/update_test.py diff --git a/superset-frontend/src/features/tags/TagModal.test.tsx b/superset-frontend/src/features/tags/TagModal.test.tsx new file mode 100644 index 0000000000000..a033b44cec257 --- /dev/null +++ b/superset-frontend/src/features/tags/TagModal.test.tsx @@ -0,0 +1,46 @@ +import React from 'react'; +import { render, screen } from 'spec/helpers/testing-library'; +import TagModal from 'src/features/tags/TagModal'; +import fetchMock from 'fetch-mock'; +import { Tag } from 'src/views/CRUD/types'; + +const mockedProps = { + onHide: () => {}, + refreshData: () => {}, + addSuccessToast: () => {}, + addDangerToast: () => {}, + show: true, +}; + +const fetchEditFetchObjects = `glob:*/api/v1/tag/get_objects/?tags=*`; + +test('should render', () => { + const { container } = render(); + expect(container).toBeInTheDocument(); +}); + +test('renders correctly in create mode', () => { + const { getByPlaceholderText, getByText } = render( + , + ); + + expect(getByPlaceholderText('Name of your tag')).toBeInTheDocument(); + expect(getByText('Create Tag')).toBeInTheDocument(); +}); + +test('renders correctly in edit mode', () => { + fetchMock.get(fetchEditFetchObjects, [[]]); + const editTag: Tag = { + id: 1, + name: 'Test Tag', + description: 'A test tag', + type: 'dashboard', + changed_on_delta_humanized: '', + created_by: {}, + }; + + render(); + expect(screen.getByPlaceholderText(/name of your tag/i)).toHaveValue( + editTag.name, + ); +}); diff --git a/superset-frontend/src/features/tags/TagModal.tsx b/superset-frontend/src/features/tags/TagModal.tsx new file mode 100644 index 0000000000000..bbe32102c6980 --- /dev/null +++ b/superset-frontend/src/features/tags/TagModal.tsx @@ -0,0 +1,321 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import React, { ChangeEvent, useState, useEffect } from 'react'; +import rison from 'rison'; +import Modal from 'src/components/Modal'; +import AsyncSelect from 'src/components/Select/AsyncSelect'; +import { FormLabel } from 'src/components/Form'; +import { t, SupersetClient } from '@superset-ui/core'; +import { Input } from 'antd'; +import { Divider } from 'src/components'; +import Button from 'src/components/Button'; +import { Tag } from 'src/views/CRUD/types'; +import { fetchObjects } from 'src/features/tags/tags'; + +interface TaggableResourceOption { + label: string; + value: number; + key: number; +} + +export enum TaggableResources { + Chart = 'chart', + Dashboard = 'dashboard', + SavedQuery = 'query', +} + +interface TagModalProps { + onHide: () => void; + refreshData: () => void; + addSuccessToast: (msg: string) => void; + addDangerToast: (msg: string) => void; + show: boolean; + editTag?: Tag | null; +} + +const TagModal: React.FC = ({ + show, + onHide, + editTag, + refreshData, + addSuccessToast, + addDangerToast, +}) => { + const [dashboardsToTag, setDashboardsToTag] = useState< + TaggableResourceOption[] + >([]); + const [chartsToTag, setChartsToTag] = useState([]); + const [savedQueriesToTag, setSavedQueriesToTag] = useState< + TaggableResourceOption[] + >([]); + + const [tagName, setTagName] = useState(''); + const [description, setDescription] = useState(''); + + const isEditMode = !!editTag; + const modalTitle = isEditMode ? 'Edit Tag' : 'Create Tag'; + + const clearResources = () => { + setDashboardsToTag([]); + setChartsToTag([]); + setSavedQueriesToTag([]); + }; + + useEffect(() => { + const resourceMap: { [key: string]: TaggableResourceOption[] } = { + [TaggableResources.Dashboard]: [], + [TaggableResources.Chart]: [], + [TaggableResources.SavedQuery]: [], + }; + + const updateResourceOptions = ({ id, name, type }: Tag) => { + const resourceOptions = resourceMap[type]; + if (resourceOptions) { + resourceOptions.push({ + value: id, + label: name, + key: id, + }); + } + }; + clearResources(); + if (isEditMode) { + fetchObjects( + { tags: editTag.name, types: null }, + (data: Tag[]) => { + data.forEach(updateResourceOptions); + setDashboardsToTag(resourceMap[TaggableResources.Dashboard]); + setChartsToTag(resourceMap[TaggableResources.Chart]); + setSavedQueriesToTag(resourceMap[TaggableResources.SavedQuery]); + }, + (error: Response) => { + addDangerToast('Error Fetching Tagged Objects'); + }, + ); + setTagName(editTag.name); + setDescription(editTag.description); + } + }, [editTag]); + + const loadData = async ( + search: string, + page: number, + pageSize: number, + columns: string[], + filterColumn: string, + orderColumn: string, + endpoint: string, + ) => { + const queryParams = rison.encode({ + columns, + filters: [ + { + col: filterColumn, + opr: 'ct', + value: search, + }, + ], + page, + order_column: orderColumn, + }); + + const { json } = await SupersetClient.get({ + endpoint: `/api/v1/${endpoint}/?q=${queryParams}`, + }); + const { result, count } = json; + + return { + data: result.map((item: { id: number }) => ({ + value: item.id, + label: item[filterColumn], + })), + totalCount: count, + }; + }; + + const loadCharts = async (search: string, page: number, pageSize: number) => + loadData( + search, + page, + pageSize, + ['id', 'slice_name'], + 'slice_name', + 'slice_name', + 'chart', + ); + + const loadDashboards = async ( + search: string, + page: number, + pageSize: number, + ) => + loadData( + search, + page, + pageSize, + ['id', 'dashboard_title'], + 'dashboard_title', + 'dashboard_title', + 'dashboard', + ); + + const loadQueries = async (search: string, page: number, pageSize: number) => + loadData( + search, + page, + pageSize, + ['id', 'label'], + 'label', + 'label', + 'saved_query', + ); + + const handleOptionChange = (resource: TaggableResources, data: any) => { + if (resource === TaggableResources.Dashboard) setDashboardsToTag(data); + else if (resource === TaggableResources.Chart) setChartsToTag(data); + else if (resource === TaggableResources.SavedQuery) + setSavedQueriesToTag(data); + }; + + const handleTagNameChange = (ev: ChangeEvent) => + setTagName(ev.target.value); + const handleDescriptionChange = (ev: ChangeEvent) => + setDescription(ev.target.value); + + const onSave = () => { + const dashboards = dashboardsToTag.map(dash => ['dashboard', dash.value]); + const charts = chartsToTag.map(chart => ['chart', chart.value]); + const savedQueries = savedQueriesToTag.map(q => ['query', q.value]); + + if (isEditMode) { + SupersetClient.put({ + endpoint: `/api/v1/tag/${editTag.id}`, + jsonPayload: { + description, + name: tagName, + objects_to_tag: [...dashboards, ...charts, ...savedQueries], + }, + }).then(({ json = {} }) => { + refreshData(); + addSuccessToast(t('Tag updated')); + }); + } else { + SupersetClient.post({ + endpoint: `/api/v1/tag/`, + jsonPayload: { + description, + name: tagName, + objects_to_tag: [...dashboards, ...charts, ...savedQueries], + }, + }).then(({ json = {} }) => { + refreshData(); + addSuccessToast(t('Tag created')); + }); + } + onHide(); + }; + + return ( + { + setTagName(''); + setDescription(''); + setDashboardsToTag([]); + setChartsToTag([]); + setSavedQueriesToTag([]); + onHide(); + }} + show={show} + footer={ +
+ + +
+ } + > + <> + {t('Tag Name')} + + {t('Description')} + + + + handleOptionChange(TaggableResources.Dashboard, value) + } + header={{t('Dashboards')}} + allowClear + /> + handleOptionChange(TaggableResources.Chart, value)} + header={{t('Charts')}} + allowClear + /> + + handleOptionChange(TaggableResources.SavedQuery, value) + } + header={{t('Saved Queries')}} + allowClear + /> + +
+ ); +}; + +export default TagModal; diff --git a/superset-frontend/src/features/tags/tags.ts b/superset-frontend/src/features/tags/tags.ts index ff0b8f3a339d3..97b5b094b3dbe 100644 --- a/superset-frontend/src/features/tags/tags.ts +++ b/superset-frontend/src/features/tags/tags.ts @@ -55,6 +55,16 @@ export function fetchAllTags( .catch(response => error(response)); } +export function fetchSingleTag( + name: string, + callback: (json: JsonObject) => void, + error: (response: Response) => void, +) { + SupersetClient.get({ endpoint: `/api/v1/tag` }) + .then(({ json }) => callback(json)) + .catch(response => error(response)); +} + export function fetchTags( { objectType, diff --git a/superset-frontend/src/pages/Tags/index.tsx b/superset-frontend/src/pages/Tags/index.tsx index 03a2b1da9c884..fa623f03c5bea 100644 --- a/superset-frontend/src/pages/Tags/index.tsx +++ b/superset-frontend/src/pages/Tags/index.tsx @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +import React, { useMemo, useCallback, useState } from 'react'; import { isFeatureEnabled, FeatureFlag, t } from '@superset-ui/core'; -import React, { useMemo, useCallback } from 'react'; import { createFetchRelated, createErrorHandler, @@ -41,23 +41,9 @@ import { deleteTags } from 'src/features/tags/tags'; import { Tag as AntdTag } from 'antd'; import { Tag } from 'src/views/CRUD/types'; import TagCard from 'src/features/tags/TagCard'; +import TagModal from 'src/features/tags/TagModal'; import FaveStar from 'src/components/FaveStar'; -const emptyState = { - title: t('No Tags created'), - image: 'dashboard.svg', - description: - 'Create a new tag and assign it to existing entities like charts or dashboards', - buttonAction: () => {}, - // todo(hughhh): Add this back once Tag modal is functional - // buttonText: ( - // <> - // {' '} - // {'Create a new Tag'}{' '} - // - // ), -}; - const PAGE_SIZE = 25; interface TagListProps { @@ -90,6 +76,8 @@ function TagList(props: TagListProps) { refreshData, } = useListViewResource('tag', t('tag'), addDangerToast); + const [showTagModal, setShowTagModal] = useState(false); + const [tagToEdit, setTagToEdit] = useState(null); const tagIds = useMemo(() => tags.map(c => c.id), [tags]); const [saveFavoriteStatus, favoriteStatus] = useFavoriteStatus( 'tag', @@ -101,6 +89,7 @@ function TagList(props: TagListProps) { const userKey = dangerouslyGetItemDoNotUse(userId?.toString(), null); const canDelete = hasPerm('can_write'); + const canEdit = hasPerm('can_write'); const initialSort = [{ id: 'changed_on_delta_humanized', desc: true }]; @@ -114,6 +103,25 @@ function TagList(props: TagListProps) { refreshData(); } + const handleTagEdit = (tag: Tag) => { + setShowTagModal(true); + setTagToEdit(tag); + }; + + const emptyState = { + title: t('No Tags created'), + image: 'dashboard.svg', + description: + 'Create a new tag and assign it to existing entities like charts or dashboards', + buttonAction: () => setShowTagModal(true), + buttonText: ( + <> + {' '} + {'Create a new Tag'}{' '} + + ), + }; + const columns = useMemo( () => [ { @@ -175,6 +183,7 @@ function TagList(props: TagListProps) { Cell: ({ row: { original } }: any) => { const handleDelete = () => handleTagsDelete([original], addSuccessToast, addDangerToast); + const handleEdit = () => handleTagEdit(original); return ( {canDelete && ( @@ -206,6 +215,22 @@ function TagList(props: TagListProps) { )} )} + {canEdit && ( + + + + + + )} ); }, @@ -303,6 +328,7 @@ function TagList(props: TagListProps) { ); const subMenuButtons: SubMenuProps['buttons'] = []; + if (canDelete) { subMenuButtons.push({ name: t('Bulk select'), @@ -312,11 +338,30 @@ function TagList(props: TagListProps) { }); } + // render new 'New Tag' btn + subMenuButtons.push({ + name: t('New Tag'), + buttonStyle: 'primary', + 'data-test': 'bulk-select', + onClick: () => setShowTagModal(true), + }); + const handleBulkDelete = (tagsToDelete: Tag[]) => handleTagsDelete(tagsToDelete, addSuccessToast, addDangerToast); return ( <> + { + setShowTagModal(false); + setTagToEdit(null); + }} + editTag={tagToEdit} + refreshData={refreshData} + addSuccessToast={addSuccessToast} + addDangerToast={addDangerToast} + /> & diff --git a/superset/daos/tag.py b/superset/daos/tag.py index 8e4437f49d447..1a0843cc0e0a0 100644 --- a/superset/daos/tag.py +++ b/superset/daos/tag.py @@ -29,6 +29,7 @@ from superset.models.slice import Slice from superset.models.sql_lab import SavedQuery from superset.tags.commands.exceptions import TagNotFoundError +from superset.tags.commands.utils import to_object_type from superset.tags.models import ( get_tag, ObjectTypes, @@ -363,3 +364,46 @@ def favorited_ids(tags: list[Tag]) -> list[int]: ) .all() ] + + @staticmethod + def create_tag_relationship( + objects_to_tag: list[tuple[ObjectTypes, int]], tag: Tag + ) -> None: + """ + Creates a tag relationship between the given objects and the specified tag. + This function iterates over a list of objects, each specified by a type + and an id, and creates a TaggedObject for each one, associating it with + the provided tag. All created TaggedObjects are collected in a list. + Args: + objects_to_tag (List[Tuple[ObjectTypes, int]]): A list of tuples, each + containing an ObjectType and an id, representing the objects to be tagged. + + tag (Tag): The tag to be associated with the specified objects. + Returns: + None. + """ + tagged_objects = [] + if not tag: + raise TagNotFoundError() + + current_tagged_objects = { + (obj.object_type, obj.object_id) for obj in tag.objects + } + updated_tagged_objects = { + (to_object_type(obj[0]), obj[1]) for obj in objects_to_tag + } + tagged_objects_to_delete = current_tagged_objects - updated_tagged_objects + + for object_type, object_id in updated_tagged_objects: + # create rows for new objects, and skip tags that already exist + if (object_type, object_id) not in current_tagged_objects: + tagged_objects.append( + TaggedObject(object_id=object_id, object_type=object_type, tag=tag) + ) + + for object_type, object_id in tagged_objects_to_delete: + # delete objects that were removed + TagDAO.delete_tagged_object(object_type, object_id, tag.name) # type: ignore + + db.session.add_all(tagged_objects) + db.session.commit() diff --git a/superset/tags/api.py b/superset/tags/api.py index a12461e8e4635..a760b33921c03 100644 --- a/superset/tags/api.py +++ b/superset/tags/api.py @@ -20,12 +20,16 @@ from flask import request, Response from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface +from marshmallow import ValidationError from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.daos.tag import TagDAO from superset.exceptions import MissingUserContextException from superset.extensions import event_logger -from superset.tags.commands.create import CreateCustomTagCommand +from superset.tags.commands.create import ( + CreateCustomTagCommand, + CreateCustomTagWithRelationshipsCommand, +) from superset.tags.commands.delete import DeleteTaggedObjectCommand, DeleteTagsCommand from superset.tags.commands.exceptions import ( TagCreateFailedError, @@ -34,7 +38,9 @@ TaggedObjectNotFoundError, TagInvalidError, TagNotFoundError, + TagUpdateFailedError, ) +from superset.tags.commands.update import UpdateTagCommand from superset.tags.models import ObjectTypes, Tag from superset.tags.schemas import ( delete_tags_schema, @@ -42,6 +48,7 @@ TaggedObjectEntityResponseSchema, TagGetResponseSchema, TagPostSchema, + TagPutSchema, ) from superset.views.base_api import ( BaseSupersetModelRestApi, @@ -77,6 +84,7 @@ class TagRestApi(BaseSupersetModelRestApi): "id", "name", "type", + "description", "changed_by.first_name", "changed_by.last_name", "changed_on_delta_humanized", @@ -90,6 +98,7 @@ class TagRestApi(BaseSupersetModelRestApi): "id", "name", "type", + "description", "changed_by.first_name", "changed_by.last_name", "changed_on_delta_humanized", @@ -108,6 +117,7 @@ class TagRestApi(BaseSupersetModelRestApi): allowed_rel_fields = {"created_by"} add_model_schema = TagPostSchema() + edit_model_schema = TagPutSchema() tag_get_response_schema = TagGetResponseSchema() object_entity_response_schema = TaggedObjectEntityResponseSchema() @@ -131,6 +141,131 @@ def __repr__(self) -> str: f'{self.appbuilder.app.config["VERSION_SHA"]}' ) + @expose("/", methods=("POST",)) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post", + log_to_statsd=False, + ) + def post(self) -> Response: + """Creates a new Tags and tag items + --- + post: + description: >- + Create a new Tag + requestBody: + description: Tag schema + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/{{self.__class__.__name__}}.post' + responses: + 201: + description: Tag added + content: + application/json: + schema: + type: object + properties: + id: + type: number + result: + $ref: '#/components/schemas/{{self.__class__.__name__}}.post' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + try: + item = self.add_model_schema.load(request.json) + except ValidationError as error: + return self.response_400(message=error.messages) + try: + CreateCustomTagWithRelationshipsCommand(item).run() + return self.response(201) + except TagInvalidError as ex: + return self.response_422(message=ex.normalized_messages()) + except TagCreateFailedError as ex: + logger.error( + "Error creating model %s: %s", + self.__class__.__name__, + str(ex), + exc_info=True, + ) + return self.response_500(message=str(ex)) + + @expose("/", methods=("PUT",)) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", + log_to_statsd=False, + ) + def put(self, pk: int) -> Response: + """Changes a Tag + --- + put: + description: >- + Changes a Tag. + parameters: + - in: path + schema: + type: integer + name: pk + requestBody: + description: Chart schema + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/{{self.__class__.__name__}}.put' + responses: + 200: + description: Tag changed + content: + application/json: + schema: + type: object + properties: + id: + type: number + result: + $ref: '#/components/schemas/{{self.__class__.__name__}}.put' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + try: + item = self.edit_model_schema.load(request.json) + # This validates custom Schema with custom validations + except ValidationError as error: + return self.response_400(message=error.messages) + item = request.json + try: + changed_model = UpdateTagCommand(pk, item).run() + response = self.response(200, id=changed_model.id, result=item) + except TagUpdateFailedError as ex: + response = self.response_422(message=str(ex)) + + return response + @expose("///", methods=("POST",)) @protect() @safe @@ -201,7 +336,7 @@ def add_objects(self, object_type: ObjectTypes, object_id: int) -> Response: str(ex), exc_info=True, ) - return self.response_422(message=str(ex)) + return self.response_500(message=str(ex)) @expose("////", methods=("DELETE",)) @protect() @@ -387,7 +522,7 @@ def get_objects(self) -> Response: str(ex), exc_info=True, ) - return self.response_422(message=str(ex)) + return self.response_500(message=str(ex)) @expose("/favorite_status/", methods=("GET",)) @protect() diff --git a/superset/tags/commands/create.py b/superset/tags/commands/create.py index 7e9f040015e2b..5c30b548bda3a 100644 --- a/superset/tags/commands/create.py +++ b/superset/tags/commands/create.py @@ -15,13 +15,15 @@ # specific language governing permissions and limitations # under the License. import logging +from typing import Any +from superset import db from superset.commands.base import BaseCommand, CreateMixin from superset.daos.exceptions import DAOCreateFailedError from superset.daos.tag import TagDAO from superset.tags.commands.exceptions import TagCreateFailedError, TagInvalidError from superset.tags.commands.utils import to_object_type -from superset.tags.models import ObjectTypes +from superset.tags.models import ObjectTypes, TagTypes logger = logging.getLogger(__name__) @@ -60,3 +62,45 @@ def validate(self) -> None: ) if exceptions: raise TagInvalidError(exceptions=exceptions) + + +class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand): + def __init__(self, data: dict[str, Any]): + self._tag = data["name"] + self._objects_to_tag = data.get("objects_to_tag") + self._description = data.get("description") + + def run(self) -> None: + self.validate() + try: + tag = TagDAO.get_by_name(self._tag.strip(), TagTypes.custom) + if self._objects_to_tag: + TagDAO.create_tag_relationship( + objects_to_tag=self._objects_to_tag, tag=tag + ) + + if self._description: + tag.description = self._description + db.session.commit() + + except DAOCreateFailedError as ex: + logger.exception(ex.exception) + raise TagCreateFailedError() from ex + + def validate(self) -> None: + exceptions = [] + # Validate object_id + if self._objects_to_tag: + if any(obj_id == 0 for obj_type, obj_id in self._objects_to_tag): + exceptions.append(TagInvalidError()) + + # Validate object type + for obj_type, obj_id in self._objects_to_tag: + object_type = to_object_type(obj_type) + if not object_type: + exceptions.append( + TagInvalidError(f"invalid object type {object_type}") + ) + + if exceptions: + raise TagInvalidError(exceptions=exceptions) diff --git a/superset/tags/commands/exceptions.py b/superset/tags/commands/exceptions.py index 9847c949bf7ec..6778c8e221a1f 100644 --- a/superset/tags/commands/exceptions.py +++ b/superset/tags/commands/exceptions.py @@ -23,6 +23,8 @@ CommandInvalidError, CreateFailedError, DeleteFailedError, + ObjectNotFoundError, + UpdateFailedError, ) @@ -34,6 +36,10 @@ class TagCreateFailedError(CreateFailedError): message = _("Tag could not be created.") +class TagUpdateFailedError(UpdateFailedError): + message = _("Tag could not be updated.") + + class TagDeleteFailedError(DeleteFailedError): message = _("Tag could not be deleted.") @@ -42,7 +48,7 @@ class TaggedObjectDeleteFailedError(DeleteFailedError): message = _("Tagged Object could not be deleted.") -class TagNotFoundError(CommandException): +class TagNotFoundError(ObjectNotFoundError): def __init__(self, tag_name: Optional[str] = None) -> None: message = "Tag not found." if tag_name: diff --git a/superset/tags/commands/update.py b/superset/tags/commands/update.py new file mode 100644 index 0000000000000..a13e4e8e7bbb0 --- /dev/null +++ b/superset/tags/commands/update.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Any, Optional + +from flask_appbuilder.models.sqla import Model + +from superset import db +from superset.commands.base import BaseCommand, UpdateMixin +from superset.daos.tag import TagDAO +from superset.tags.commands.exceptions import TagInvalidError, TagNotFoundError +from superset.tags.commands.utils import to_object_type +from superset.tags.models import Tag + +logger = logging.getLogger(__name__) + + +class UpdateTagCommand(UpdateMixin, BaseCommand): + def __init__(self, model_id: int, data: dict[str, Any]): + self._model_id = model_id + self._properties = data.copy() + self._model: Optional[Tag] = None + + def run(self) -> Model: + self.validate() + if self._model: + if self._properties.get("objects_to_tag"): + # todo(hugh): can this manage duplication + TagDAO.create_tag_relationship( + objects_to_tag=self._properties["objects_to_tag"], + tag=self._model, + ) + if description := self._properties.get("description"): + self._model.description = description + if tag_name := self._properties.get("name"): + self._model.name = tag_name + + db.session.add(self._model) + db.session.commit() + + return self._model + + def validate(self) -> None: + exceptions = [] + # Validate/populate model exists + self._model = TagDAO.find_by_id(self._model_id) + if not self._model: + raise TagNotFoundError() + + # Validate object_id + if objects_to_tag := self._properties.get("objects_to_tag"): + if any(obj_id == 0 for obj_type, obj_id in objects_to_tag): + exceptions.append(TagInvalidError(" invalid object_id")) + + # Validate object type + for obj_type, obj_id in objects_to_tag: + object_type = to_object_type(obj_type) + if not object_type: + exceptions.append( + TagInvalidError(f"invalid object type {object_type}") + ) + + if exceptions: + raise TagInvalidError(exceptions=exceptions) diff --git a/superset/tags/schemas.py b/superset/tags/schemas.py index f519901a8bb84..89f15d4bf8f54 100644 --- a/superset/tags/schemas.py +++ b/superset/tags/schemas.py @@ -55,4 +55,18 @@ class TagGetResponseSchema(Schema): class TagPostSchema(Schema): - tags = fields.List(fields.String()) + name = fields.String() + description = fields.String(required=False) + # resource id's to tag with tag + objects_to_tag = fields.List( + fields.Tuple((fields.String(), fields.Int())), required=False + ) + + +class TagPutSchema(Schema): + name = fields.String() + description = fields.String(required=False) + # resource id's to tag with tag + objects_to_tag = fields.List( + fields.Tuple((fields.String(), fields.Int())), required=False + ) diff --git a/tests/integration_tests/tags/api_tests.py b/tests/integration_tests/tags/api_tests.py index d1231db97ea55..e0f4de87eb750 100644 --- a/tests/integration_tests/tags/api_tests.py +++ b/tests/integration_tests/tags/api_tests.py @@ -53,6 +53,7 @@ "id", "name", "type", + "description", "changed_by.first_name", "changed_by.last_name", "changed_on_delta_humanized", @@ -457,3 +458,46 @@ def test_delete_favorite_tag_user_not_found(self, flask_g): rv = self.client.delete(uri, follow_redirects=True) self.assertEqual(rv.status_code, 422) + + @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") + def test_post_tag(self): + self.login(username="admin") + uri = f"api/v1/tag/" + dashboard = ( + db.session.query(Dashboard) + .filter(Dashboard.dashboard_title == "World Bank's Data") + .first() + ) + rv = self.client.post( + uri, + json={"name": "my_tag", "objects_to_tag": [["dashboard", dashboard.id]]}, + ) + + self.assertEqual(rv.status_code, 201) + user_id = self.get_user(username="admin").get_id() + tag = ( + db.session.query(Tag) + .filter(Tag.name == "my_tag", Tag.type == TagTypes.custom) + .one_or_none() + ) + assert tag is not None + + @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") + @pytest.mark.usefixtures("create_tags") + def test_put_tag(self): + self.login(username="admin") + + tag_to_update = db.session.query(Tag).first() + uri = f"api/v1/tag/{tag_to_update.id}" + rv = self.client.put( + uri, json={"name": "new_name", "description": "new description"} + ) + + self.assertEqual(rv.status_code, 200) + + tag = ( + db.session.query(Tag) + .filter(Tag.name == "new_name", Tag.description == "new description") + .one_or_none() + ) + assert tag is not None diff --git a/tests/unit_tests/dao/tag_test.py b/tests/unit_tests/dao/tag_test.py index 3f5666d2b54a0..476c51e45db31 100644 --- a/tests/unit_tests/dao/tag_test.py +++ b/tests/unit_tests/dao/tag_test.py @@ -144,3 +144,31 @@ def test_user_favorite_tag_exc_raise(mocker): mock_session.commit.side_effect = Exception("DB Error") with pytest.raises(Exception): TagDAO.remove_user_favorite_tag(1) + + +def test_create_tag_relationship(mocker): + from superset.daos.tag import TagDAO + from superset.tags.models import ( # Assuming these are defined in the same module + ObjectTypes, + TaggedObject, + ) + + mock_session = mocker.patch("superset.daos.tag.db.session") + + # Define a list of objects to tag + objects_to_tag = [ + (ObjectTypes.query, 1), + (ObjectTypes.chart, 2), + (ObjectTypes.dashboard, 3), + ] + + # Call the function + tag = TagDAO.get_by_name("test_tag") + TagDAO.create_tag_relationship(objects_to_tag, tag) + + # Verify that the correct number of TaggedObjects are added to the session + assert mock_session.add_all.call_count == 1 + assert len(mock_session.add_all.call_args[0][0]) == len(objects_to_tag) + + # Verify that commit is called + mock_session.commit.assert_called_once() diff --git a/tests/unit_tests/tags/__init__.py b/tests/unit_tests/tags/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/unit_tests/tags/commands/create_test.py b/tests/unit_tests/tags/commands/create_test.py new file mode 100644 index 0000000000000..a188625b403f5 --- /dev/null +++ b/tests/unit_tests/tags/commands/create_test.py @@ -0,0 +1,110 @@ +import pytest +from sqlalchemy.orm.session import Session + +from superset.utils.core import DatasourceType + + +@pytest.fixture +def session_with_data(session: Session): + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.models.core import Database + from superset.models.dashboard import Dashboard + from superset.models.slice import Slice + from superset.models.sql_lab import Query, SavedQuery + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + slice_obj = Slice( + id=1, + datasource_id=1, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_perm_table", + slice_name="slice_name", + ) + + db = Database(database_name="my_database", sqlalchemy_uri="postgresql://") + + columns = [ + TableColumn(column_name="a", type="INTEGER"), + ] + + saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo") + + dashboard_obj = Dashboard( + id=100, + dashboard_title="test_dashboard", + slug="test_slug", + slices=[], + published=True, + ) + + session.add(slice_obj) + session.add(db) + session.add(saved_query) + session.add(dashboard_obj) + session.commit() + yield session + + +def test_create_command_success(session_with_data: Session): + from superset.connectors.sqla.models import SqlaTable + from superset.daos.tag import TagDAO + from superset.models.dashboard import Dashboard + from superset.models.slice import Slice + from superset.models.sql_lab import Query, SavedQuery + from superset.tags.commands.create import CreateCustomTagWithRelationshipsCommand + from superset.tags.models import ObjectTypes, TaggedObject + + # Define a list of objects to tag + query = session_with_data.query(SavedQuery).first() + chart = session_with_data.query(Slice).first() + dashboard = session_with_data.query(Dashboard).first() + + objects_to_tag = [ + (ObjectTypes.query, query.id), + (ObjectTypes.chart, chart.id), + (ObjectTypes.dashboard, dashboard.id), + ] + + CreateCustomTagWithRelationshipsCommand( + data={"name": "test_tag", "objects_to_tag": objects_to_tag} + ).run() + + assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag) + for object_type, object_id in objects_to_tag: + assert ( + session_with_data.query(TaggedObject) + .filter( + TaggedObject.object_type == object_type, + TaggedObject.object_id == object_id, + ) + .one_or_none() + is not None + ) + + +def test_create_command_failed_validate(session_with_data: Session): + from superset.connectors.sqla.models import SqlaTable + from superset.daos.tag import TagDAO + from superset.models.dashboard import Dashboard + from superset.models.slice import Slice + from superset.models.sql_lab import Query, SavedQuery + from superset.tags.commands.create import CreateCustomTagWithRelationshipsCommand + from superset.tags.commands.exceptions import TagInvalidError + from superset.tags.models import ObjectTypes, TaggedObject + + query = session_with_data.query(SavedQuery).first() + chart = session_with_data.query(Slice).first() + dashboard = session_with_data.query(Dashboard).first() + + objects_to_tag = [ + (ObjectTypes.query, query.id), + (ObjectTypes.chart, chart.id), + (ObjectTypes.dashboard, 0), + ] + + with pytest.raises(TagInvalidError): + CreateCustomTagWithRelationshipsCommand( + data={"name": "test_tag", "objects_to_tag": objects_to_tag} + ).run() diff --git a/tests/unit_tests/tags/commands/update_test.py b/tests/unit_tests/tags/commands/update_test.py new file mode 100644 index 0000000000000..2c2454547eb17 --- /dev/null +++ b/tests/unit_tests/tags/commands/update_test.py @@ -0,0 +1,160 @@ +import pytest +from sqlalchemy.orm.session import Session + +from superset.utils.core import DatasourceType + + +@pytest.fixture +def session_with_data(session: Session): + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.models.core import Database + from superset.models.dashboard import Dashboard + from superset.models.slice import Slice + from superset.models.sql_lab import Query, SavedQuery + from superset.tags.models import Tag + + engine = session.get_bind() + Tag.metadata.create_all(engine) # pylint: disable=no-member + + slice_obj = Slice( + id=1, + datasource_id=1, + datasource_type=DatasourceType.TABLE, + datasource_name="tmp_perm_table", + slice_name="slice_name", + ) + + db = Database(database_name="my_database", sqlalchemy_uri="postgresql://") + + columns = [ + TableColumn(column_name="a", type="INTEGER"), + ] + + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=columns, + metrics=[], + database=db, + ) + + dashboard_obj = Dashboard( + id=100, + dashboard_title="test_dashboard", + slug="test_slug", + slices=[], + published=True, + ) + + saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo") + + tag = Tag(name="test_name", description="test_description") + + session.add(slice_obj) + session.add(dashboard_obj) + session.add(tag) + session.commit() + yield session + + +def test_update_command_success(session_with_data: Session): + from superset.daos.tag import TagDAO + from superset.models.dashboard import Dashboard + from superset.tags.commands.update import UpdateTagCommand + from superset.tags.models import ObjectTypes, TaggedObject + + dashboard = session_with_data.query(Dashboard).first() + + objects_to_tag = [ + (ObjectTypes.dashboard, dashboard.id), + ] + + tag_to_update = TagDAO.find_by_name("test_name") + changed_model = UpdateTagCommand( + tag_to_update.id, + { + "name": "new_name", + "description": "new_description", + "objects_to_tag": objects_to_tag, + }, + ).run() + + updated_tag = TagDAO.find_by_name("new_name") + assert updated_tag is not None + assert updated_tag.description == "new_description" + assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag) + + +def test_update_command_success_duplicates(session_with_data: Session): + from superset.daos.tag import TagDAO + from superset.models.dashboard import Dashboard + from superset.models.slice import Slice + from superset.tags.commands.create import CreateCustomTagWithRelationshipsCommand + from superset.tags.commands.update import UpdateTagCommand + from superset.tags.models import ObjectTypes, TaggedObject + + dashboard = session_with_data.query(Dashboard).first() + chart = session_with_data.query(Slice).first() + + objects_to_tag = [ + (ObjectTypes.dashboard, dashboard.id), + ] + + CreateCustomTagWithRelationshipsCommand( + data={"name": "test_tag", "objects_to_tag": objects_to_tag} + ).run() + + tag_to_update = TagDAO.find_by_name("test_tag") + + objects_to_tag = [ + (ObjectTypes.chart, chart.id), + ] + changed_model = UpdateTagCommand( + tag_to_update.id, + { + "name": "new_name", + "description": "new_description", + "objects_to_tag": objects_to_tag, + }, + ).run() + + updated_tag = TagDAO.find_by_name("new_name") + assert updated_tag is not None + assert updated_tag.description == "new_description" + assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag) + assert changed_model.objects[0].object_id == chart.id + + +def test_update_command_failed_validation(session_with_data: Session): + from superset.daos.tag import TagDAO + from superset.models.dashboard import Dashboard + from superset.models.slice import Slice + from superset.tags.commands.create import CreateCustomTagWithRelationshipsCommand + from superset.tags.commands.exceptions import TagInvalidError + from superset.tags.commands.update import UpdateTagCommand + from superset.tags.models import ObjectTypes, TaggedObject + + dashboard = session_with_data.query(Dashboard).first() + chart = session_with_data.query(Slice).first() + objects_to_tag = [ + (ObjectTypes.chart, chart.id), + ] + + CreateCustomTagWithRelationshipsCommand( + data={"name": "test_tag", "objects_to_tag": objects_to_tag} + ).run() + + tag_to_update = TagDAO.find_by_name("test_tag") + + objects_to_tag = [ + (0, dashboard.id), # type: ignore + ] + + with pytest.raises(TagInvalidError): + UpdateTagCommand( + tag_to_update.id, + { + "name": "new_name", + "description": "new_description", + "objects_to_tag": objects_to_tag, + }, + ).run()