Skip to content

Commit

Permalink
fix(tags): fix clears delete on Tags Modal (apache#25470)
Browse files Browse the repository at this point in the history
Co-authored-by: Beto Dealmeida <roberto@dealmeida.net>
  • Loading branch information
hughhhh and betodealmeida authored Oct 5, 2023
1 parent 9578b39 commit 6cc2b36
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 60 deletions.
7 changes: 6 additions & 1 deletion superset/daos/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,12 @@ def create_tag_relationship(
updated_tagged_objects = {
(to_object_type(obj[0]), obj[1]) for obj in objects_to_tag
}
tagged_objects_to_delete = current_tagged_objects - updated_tagged_objects

tagged_objects_to_delete = (
current_tagged_objects
if not objects_to_tag
else current_tagged_objects - updated_tagged_objects
)

for object_type, object_id in updated_tagged_objects:
# create rows for new objects, and skip tags that already exist
Expand Down
61 changes: 24 additions & 37 deletions superset/tags/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,25 +67,22 @@ def validate(self) -> None:

class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
def __init__(self, data: dict[str, Any], bulk_create: bool = False):
self._tag = data["name"]
self._objects_to_tag = data.get("objects_to_tag")
self._description = data.get("description")
self._properties = data.copy()
self._bulk_create = bulk_create

def run(self) -> None:
self.validate()

try:
tag = TagDAO.get_by_name(self._tag.strip(), TagTypes.custom)
if self._objects_to_tag:
TagDAO.create_tag_relationship(
objects_to_tag=self._objects_to_tag,
tag=tag,
bulk_create=self._bulk_create,
)
tag_name = self._properties["name"]
tag = TagDAO.get_by_name(tag_name.strip(), TagTypes.custom)
TagDAO.create_tag_relationship(
objects_to_tag=self._properties.get("objects_to_tag", []),
tag=tag,
bulk_create=self._bulk_create,
)

if self._description:
tag.description = self._description
tag.description = self._properties.get("description", "")

db.session.commit()

Expand All @@ -95,31 +92,21 @@ def run(self) -> None:

def validate(self) -> None:
exceptions = []
# Validate object_id
if self._objects_to_tag:
if any(obj_id == 0 for obj_type, obj_id in self._objects_to_tag):
exceptions.append(TagInvalidError())

# Validate object type
skipped_tagged_objects: list[tuple[str, int]] = []
for obj_type, obj_id in self._objects_to_tag:
skipped_tagged_objects = []
object_type = to_object_type(obj_type)

if not object_type:
exceptions.append(
TagInvalidError(f"invalid object type {object_type}")
)
try:
model = to_object_model(object_type, obj_id) # type: ignore
security_manager.raise_for_ownership(model)
except SupersetSecurityException:
# skip the object if the user doesn't have access
skipped_tagged_objects.append((obj_type, obj_id))

self._objects_to_tag = set(self._objects_to_tag) - set(
skipped_tagged_objects
)
objects_to_tag = set(self._properties.get("objects_to_tag", []))
skipped_tagged_objects: set[tuple[str, int]] = set()
for obj_type, obj_id in objects_to_tag:
object_type = to_object_type(obj_type)

if not object_type:
exceptions.append(TagInvalidError(f"invalid object type {object_type}"))
try:
model = to_object_model(object_type, obj_id) # type: ignore
security_manager.raise_for_ownership(model)
except SupersetSecurityException:
# skip the object if the user doesn't have access
skipped_tagged_objects.add((obj_type, obj_id))

self._properties["objects_to_tag"] = objects_to_tag - skipped_tagged_objects

if exceptions:
raise TagInvalidError(exceptions=exceptions)
15 changes: 5 additions & 10 deletions superset/tags/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,10 @@ def __init__(self, model_id: int, data: dict[str, Any]):
def run(self) -> Model:
self.validate()
if self._model:
if self._properties.get("objects_to_tag"):
# todo(hugh): can this manage duplication
TagDAO.create_tag_relationship(
objects_to_tag=self._properties["objects_to_tag"],
tag=self._model,
)
TagDAO.create_tag_relationship(
objects_to_tag=self._properties.get("objects_to_tag", []),
tag=self._model,
)
if description := self._properties.get("description"):
self._model.description = description
if tag_name := self._properties.get("name"):
Expand All @@ -63,11 +61,8 @@ def validate(self) -> None:

# Validate object_id
if objects_to_tag := self._properties.get("objects_to_tag"):
if any(obj_id == 0 for obj_type, obj_id in objects_to_tag):
exceptions.append(TagInvalidError(" invalid object_id"))

# Validate object type
for obj_type, obj_id in objects_to_tag:
for obj_type, _ in objects_to_tag:
object_type = to_object_type(obj_type)
if not object_type:
exceptions.append(
Expand Down
4 changes: 3 additions & 1 deletion superset/tags/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from marshmallow import fields, Schema
from marshmallow.validate import Range

from superset.dashboards.schemas import UserSchema

Expand Down Expand Up @@ -60,7 +61,8 @@ class TagObjectSchema(Schema):
name = fields.String()
description = fields.String(required=False, allow_none=True)
objects_to_tag = fields.List(
fields.Tuple((fields.String(), fields.Int())), required=False
fields.Tuple((fields.String(), fields.Int(validate=Range(min=1)))),
required=False,
)


Expand Down
26 changes: 15 additions & 11 deletions tests/unit_tests/tags/commands/create_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,35 +91,39 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture)
)


def test_create_command_failed_validate(
session_with_data: Session, mocker: MockFixture
):
def test_create_command_success_clear(session_with_data: Session, mocker: MockFixture):
from superset.connectors.sqla.models import SqlaTable
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import Query, SavedQuery
from superset.tags.commands.create import CreateCustomTagWithRelationshipsCommand
from superset.tags.commands.exceptions import TagInvalidError
from superset.tags.models import ObjectTypes, TaggedObject

# Define a list of objects to tag
query = session_with_data.query(SavedQuery).first()
chart = session_with_data.query(Slice).first()
dashboard = session_with_data.query(Dashboard).first()

mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=query)
mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=chart)
mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=query)

objects_to_tag = [
(ObjectTypes.query, query.id),
(ObjectTypes.chart, chart.id),
(ObjectTypes.dashboard, 0),
(ObjectTypes.dashboard, dashboard.id),
]

with pytest.raises(TagInvalidError):
CreateCustomTagWithRelationshipsCommand(
data={"name": "test_tag", "objects_to_tag": objects_to_tag}
).run()
CreateCustomTagWithRelationshipsCommand(
data={"name": "test_tag", "objects_to_tag": objects_to_tag}
).run()
assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)

CreateCustomTagWithRelationshipsCommand(
data={"name": "test_tag", "objects_to_tag": []}
).run()

assert len(session_with_data.query(TaggedObject).all()) == 0

0 comments on commit 6cc2b36

Please sign in to comment.