Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tag): fast follow for Tags flatten api + update client with generator + some bug fixes #25309

Merged
merged 5 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions superset-frontend/src/features/tags/BulkTagModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,19 @@ const BulkTagModal: React.FC<BulkTagModalProps> = ({
addDangerToast,
}) => {
useEffect(() => {}, []);
const [tags, setTags] = useState<TaggableResourceOption[]>([]);

const onSave = async () => {
await SupersetClient.post({
endpoint: `/api/v1/tag/bulk_create`,
jsonPayload: {
tags: tags.map(tag => tag.value),
objects_to_tag: selected.map(item => [resourceName, +item.original.id]),
tags: tags.map(tag => ({
name: tag.value,
objects_to_tag: selected.map(item => [
resourceName,
+item.original.id,
eschutho marked this conversation as resolved.
Show resolved Hide resolved
]),
})),
},
})
.then(({ json = {} }) => {
Expand All @@ -66,8 +72,6 @@ const BulkTagModal: React.FC<BulkTagModalProps> = ({
setTags([]);
};

const [tags, setTags] = useState<TaggableResourceOption[]>([]);

return (
<Modal
title={t('Bulk tag')}
Expand Down
1 change: 0 additions & 1 deletion superset/daos/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,4 +412,3 @@ def create_tag_relationship(
)

db.session.add_all(tagged_objects)
db.session.commit()
5 changes: 4 additions & 1 deletion superset/tags/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,10 @@ def bulk_create(self) -> Response:
try:
for tag in item.get("tags"):
tagged_item: dict[str, Any] = self.add_model_schema.load(
{"name": tag, "objects_to_tag": item.get("objects_to_tag")}
{
"name": tag.get("name"),
"objects_to_tag": tag.get("objects_to_tag"),
}
)
CreateCustomTagWithRelationshipsCommand(
tagged_item, bulk_create=True
Expand Down
22 changes: 19 additions & 3 deletions superset/tags/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import logging
from typing import Any

from superset import db
from superset import db, security_manager
from superset.commands.base import BaseCommand, CreateMixin
from superset.daos.exceptions import DAOCreateFailedError
from superset.daos.tag import TagDAO
from superset.exceptions import SupersetSecurityException
from superset.tags.commands.exceptions import TagCreateFailedError, TagInvalidError
from superset.tags.commands.utils import to_object_type
from superset.tags.commands.utils import to_object_model, to_object_type
from superset.tags.models import ObjectTypes, TagTypes

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(self, data: dict[str, Any], bulk_create: bool = False):

def run(self) -> None:
self.validate()

try:
tag = TagDAO.get_by_name(self._tag.strip(), TagTypes.custom)
if self._objects_to_tag:
Expand All @@ -84,7 +86,8 @@ def run(self) -> None:

if self._description:
tag.description = self._description
db.session.commit()

db.session.commit()

except DAOCreateFailedError as ex:
logger.exception(ex.exception)
Expand All @@ -98,12 +101,25 @@ def validate(self) -> None:
exceptions.append(TagInvalidError())

# Validate object type
skipped_tagged_objects: list[tuple[str, int]] = []
for obj_type, obj_id in self._objects_to_tag:
skipped_tagged_objects = []
object_type = to_object_type(obj_type)

if not object_type:
exceptions.append(
TagInvalidError(f"invalid object type {object_type}")
)
try:
model = to_object_model(object_type, obj_id) # type: ignore
security_manager.raise_for_ownership(model)
except SupersetSecurityException:
# skip the object if the user doesn't have access
skipped_tagged_objects.append((obj_type, obj_id))

self._objects_to_tag = set(self._objects_to_tag) - set(
skipped_tagged_objects
)

if exceptions:
raise TagInvalidError(exceptions=exceptions)
18 changes: 18 additions & 0 deletions superset/tags/commands/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

from typing import Optional, Union

from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.daos.query import SavedQueryDAO
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
from superset.tags.models import ObjectTypes


Expand All @@ -27,3 +33,15 @@ def to_object_type(object_type: Union[ObjectTypes, int, str]) -> Optional[Object
if object_type in [type_.value, type_.name]:
return type_
return None


def to_object_model(
object_type: ObjectTypes, object_id: int
) -> Optional[Union[Dashboard, SavedQuery, Slice]]:
if ObjectTypes.dashboard == object_type:
return DashboardDAO.find_by_id(object_id)
if ObjectTypes.query == object_type:
return SavedQueryDAO.find_by_id(object_id)
if ObjectTypes.chart == object_type:
return ChartDAO.find_by_id(object_id)
return None
22 changes: 8 additions & 14 deletions superset/tags/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,21 @@ class TagGetResponseSchema(Schema):
type = fields.String()


class TagPostSchema(Schema):
class TagObjectSchema(Schema):
name = fields.String()
description = fields.String(required=False, allow_none=True)
# resource id's to tag with tag
objects_to_tag = fields.List(
fields.Tuple((fields.String(), fields.Int())), required=False
)


class TagPostBulkSchema(Schema):
tags = fields.List(fields.String())
# resource id's to tag with tag
objects_to_tag = fields.List(
fields.Tuple((fields.String(), fields.Int())), required=False
)
tags = fields.List(fields.Nested(TagObjectSchema))


class TagPutSchema(Schema):
name = fields.String()
description = fields.String(required=False, allow_none=True)
# resource id's to tag with tag
objects_to_tag = fields.List(
fields.Tuple((fields.String(), fields.Int())), required=False
)
class TagPostSchema(TagObjectSchema):
pass


class TagPutSchema(TagObjectSchema):
pass
24 changes: 19 additions & 5 deletions tests/integration_tests/tags/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,23 @@ def test_post_bulk_tag(self):
rv = self.client.post(
uri,
json={
"tags": ["tag1", "tag2", "tag3"],
"objects_to_tag": [["dashboard", dashboard.id], ["chart", chart.id]],
"tags": [
{
"name": "tag1",
"objects_to_tag": [
["dashboard", dashboard.id],
["chart", chart.id],
],
},
{
"name": "tag2",
"objects_to_tag": [["dashboard", dashboard.id]],
},
{
"name": "tag3",
"objects_to_tag": [["chart", chart.id]],
},
]
},
)

Expand All @@ -547,11 +562,10 @@ def test_post_bulk_tag(self):
TaggedObject.object_id == dashboard.id,
TaggedObject.object_type == ObjectTypes.dashboard,
)
assert tagged_objects.count() == 3
assert tagged_objects.count() == 2

tagged_objects = db.session.query(TaggedObject).filter(
# TaggedObject.tag_id.in_([tag.id for tag in tags]),
TaggedObject.object_id == chart.id,
TaggedObject.object_type == ObjectTypes.chart,
)
assert tagged_objects.count() == 3
assert tagged_objects.count() == 2
3 changes: 0 additions & 3 deletions tests/unit_tests/dao/tag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,3 @@ def test_create_tag_relationship(mocker):
# Verify that the correct number of TaggedObjects are added to the session
assert mock_session.add_all.call_count == 1
assert len(mock_session.add_all.call_args[0][0]) == len(objects_to_tag)

# Verify that commit is called
mock_session.commit.assert_called_once()
19 changes: 17 additions & 2 deletions tests/unit_tests/tags/commands/create_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session

from superset.utils.core import DatasourceType
Expand Down Expand Up @@ -47,7 +48,7 @@ def session_with_data(session: Session):
yield session


def test_create_command_success(session_with_data: Session):
def test_create_command_success(session_with_data: Session, mocker: MockFixture):
from superset.connectors.sqla.models import SqlaTable
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
Expand All @@ -61,6 +62,12 @@ def test_create_command_success(session_with_data: Session):
chart = session_with_data.query(Slice).first()
dashboard = session_with_data.query(Dashboard).first()

mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=query)

objects_to_tag = [
(ObjectTypes.query, query.id),
(ObjectTypes.chart, chart.id),
Expand All @@ -84,7 +91,9 @@ def test_create_command_success(session_with_data: Session):
)


def test_create_command_failed_validate(session_with_data: Session):
def test_create_command_failed_validate(
session_with_data: Session, mocker: MockFixture
):
from superset.connectors.sqla.models import SqlaTable
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
Expand All @@ -98,6 +107,12 @@ def test_create_command_failed_validate(session_with_data: Session):
chart = session_with_data.query(Slice).first()
dashboard = session_with_data.query(Dashboard).first()

mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=query)
mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=chart)

objects_to_tag = [
(ObjectTypes.query, query.id),
(ObjectTypes.chart, chart.id),
Expand Down
35 changes: 31 additions & 4 deletions tests/unit_tests/tags/commands/update_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session

from superset.utils.core import DatasourceType
Expand Down Expand Up @@ -56,13 +57,19 @@ def session_with_data(session: Session):
yield session


def test_update_command_success(session_with_data: Session):
def test_update_command_success(session_with_data: Session, mocker: MockFixture):
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
from superset.tags.commands.update import UpdateTagCommand
from superset.tags.models import ObjectTypes, TaggedObject

dashboard = session_with_data.query(Dashboard).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
mocker.patch(
"superset.daos.dashboard.DashboardDAO.find_by_id", return_value=dashboard
)

objects_to_tag = [
(ObjectTypes.dashboard, dashboard.id),
Expand All @@ -84,7 +91,9 @@ def test_update_command_success(session_with_data: Session):
assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)


def test_update_command_success_duplicates(session_with_data: Session):
def test_update_command_success_duplicates(
session_with_data: Session, mocker: MockFixture
):
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
Expand All @@ -95,6 +104,14 @@ def test_update_command_success_duplicates(session_with_data: Session):
dashboard = session_with_data.query(Dashboard).first()
chart = session_with_data.query(Slice).first()

mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
mocker.patch(
"superset.daos.dashboard.DashboardDAO.find_by_id", return_value=dashboard
)

objects_to_tag = [
(ObjectTypes.dashboard, dashboard.id),
]
Expand Down Expand Up @@ -124,21 +141,31 @@ def test_update_command_success_duplicates(session_with_data: Session):
assert changed_model.objects[0].object_id == chart.id


def test_update_command_failed_validation(session_with_data: Session):
def test_update_command_failed_validation(
session_with_data: Session, mocker: MockFixture
):
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.tags.commands.create import CreateCustomTagWithRelationshipsCommand
from superset.tags.commands.exceptions import TagInvalidError
from superset.tags.commands.update import UpdateTagCommand
from superset.tags.models import ObjectTypes, TaggedObject
from superset.tags.models import ObjectTypes

dashboard = session_with_data.query(Dashboard).first()
chart = session_with_data.query(Slice).first()
objects_to_tag = [
(ObjectTypes.chart, chart.id),
]

mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
mocker.patch(
"superset.daos.dashboard.DashboardDAO.find_by_id", return_value=dashboard
)

CreateCustomTagWithRelationshipsCommand(
data={"name": "test_tag", "objects_to_tag": objects_to_tag}
).run()
Expand Down