diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index ed9057d1..84ba2a43 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -3,9 +3,11 @@ Changelog
 
 (Unreleased)
 ~~~~~~~~~~~~
 
+* Add a management command (``remove_orphaned_tags``) to remove orphaned tags
+* Add a fallback for when multiple tags are found in case-insensitivity mode (the earliest by PK is returned)
+* Add a ``deduplicate_tags`` management command to remove duplicate tags based on case insensitivity. This feature is enabled when ``TAGGIT_CASE_INSENSITIVE`` is set to ``True`` in the settings.
 * We no longer package tests, docs, or the sample taggit app into the distributed wheels. While we believe this shouldn't affect anything for users of the library, please tell us if you find yourself hitting issues (like around import errors)
-* Added a management command (``remove_orphaned_tags``) to remove orphaned tags
 
 6.0.0 (2024-07-27)
 ~~~~~~~~~~~~~~~~~~
diff --git a/taggit/management/commands/deduplicate_tags.py b/taggit/management/commands/deduplicate_tags.py
new file mode 100644
--- /dev/null
+++ b/taggit/management/commands/deduplicate_tags.py
@@ -0,0 +1,66 @@
+from django.conf import settings
+from django.core.management.base import BaseCommand, CommandError
+from django.db import transaction
+
+from taggit.models import Tag, TaggedItem
+
+
+class Command(BaseCommand):
+    help = "Identify and remove duplicate tags based on case insensitivity"
+
+    def handle(self, *args, **kwargs):
+        """
+        Merge tags whose names differ only by case into the earliest one (by PK)
+        """
+        if not getattr(settings, "TAGGIT_CASE_INSENSITIVE", False):
+            self.stdout.write(
+                self.style.ERROR("TAGGIT_CASE_INSENSITIVE is not enabled.")
+            )
+            return
+
+        # Walk the tags in PK order so the tag that survives is the same one
+        # the manager falls back to when a case-insensitive lookup finds
+        # several matches.
+        tag_dict = {}
+
+        for tag in Tag.objects.order_by("pk"):
+            lower_name = tag.name.lower()
+            if lower_name in tag_dict:
+                existing_tag = tag_dict[lower_name]
+                self._deduplicate_tags(existing_tag=existing_tag, tag_to_remove=tag)
+            else:
+                tag_dict[lower_name] = tag
+
+        self.stdout.write(self.style.SUCCESS("Tag deduplication complete."))
+
+    @transaction.atomic
+    def _deduplicate_tags(self, existing_tag, tag_to_remove):
+        """
+        Remove a tag by merging it into an existing tag
+
+        Raises CommandError (rolling back this merge) if tagged items still
+        point at ``tag_to_remove`` after they have been moved.
+        """
+        # If this ends up very slow for you, please file a ticket!
+        # This isn't trying to be performant, in order to keep the code simple.
+        for item in TaggedItem.objects.filter(tag=tag_to_remove):
+            # if we already have the same association on the model
+            # (via the existing tag), then we can just remove the
+            # tagged item.
+            tag_exists_other = TaggedItem.objects.filter(
+                tag=existing_tag,
+                content_type_id=item.content_type_id,
+                object_id=item.object_id,
+            ).exists()
+            if tag_exists_other:
+                item.delete()
+            else:
+                item.tag = existing_tag
+                item.save()
+
+        # this should never trigger, but can never be too sure. An explicit
+        # raise is used instead of ``assert`` so the check survives ``python -O``.
+        if TaggedItem.objects.filter(tag=tag_to_remove).exists():
+            raise CommandError("Tags were not all cleaned up!")
+
+        tag_to_remove.delete()
diff --git a/taggit/managers.py b/taggit/managers.py
index 0ff8e976..cf8e3407 100644
--- a/taggit/managers.py
+++ b/taggit/managers.py
@@ -4,6 +4,7 @@
 from django.conf import settings
 from django.contrib.contenttypes.fields import GenericRelation
 from django.contrib.contenttypes.models import ContentType
+from django.core.exceptions import MultipleObjectsReturned
 from django.db import connections, models, router
 from django.db.models import signals
 from django.db.models.fields.related import (
@@ -242,6 +243,13 @@ def _to_tag_model_instances(self, tags, tag_kwargs):
                     existing_tags_for_str[name] = tag
                 except self.through.tag_model().DoesNotExist:
                     tags_to_create.append(name)
+                except MultipleObjectsReturned:
+                    tag = (
+                        manager.filter(name__iexact=name, **tag_kwargs)
+                        .order_by("pk")
+                        .first()
+                    )
+                    existing_tags_for_str[name] = tag
         else:
             # Django is smart enough to not actually query if tag_strs is empty
             # but importantly, this is a single query for all potential tags
diff --git a/tests/test_deduplicate_tags.py b/tests/test_deduplicate_tags.py
new file mode 100644
--- /dev/null
+++ b/tests/test_deduplicate_tags.py
@@ -0,0 +1,68 @@
+from io import StringIO
+
+from django.core.management import call_command
+from django.test import TestCase, override_settings
+
+from taggit.models import Tag, TaggedItem
+from tests.models import Food, HousePet
+
+
+@override_settings(TAGGIT_CASE_INSENSITIVE=True)
+class DeduplicateTagsTests(TestCase):
+    def setUp(self):
+        self.tag1 = Tag.objects.create(name="Python")
+        self.tag2 = Tag.objects.create(name="python")
+        self.tag3 = Tag.objects.create(name="PYTHON")
+
+        self.food_item = Food.objects.create(name="Apple")
+        self.pet_item = HousePet.objects.create(name="Fido")
+
+        self.food_item.tags.add(self.tag1)
+        self.pet_item.tags.add(self.tag2)
+        self.pet_item.tags.add(self.tag3)
+
+    def test_deduplicate_tags(self):
+        self.assertEqual(Tag.objects.count(), 3)
+        self.assertEqual(TaggedItem.objects.count(), 3)
+
+        out = StringIO()
+        call_command("deduplicate_tags", stdout=out)
+
+        # the earliest tag (by PK) is the one that survives the merge
+        self.assertEqual(Tag.objects.count(), 1)
+        self.assertEqual(Tag.objects.get(), self.tag1)
+        self.assertEqual(TaggedItem.objects.count(), 2)
+
+        self.assertTrue(Tag.objects.filter(name__iexact="python").exists())
+        self.assertEqual(
+            TaggedItem.objects.filter(tag__name__iexact="python").count(), 2
+        )
+
+        self.assertIn("Tag deduplication complete.", out.getvalue())
+
+    def test_no_duplicates(self):
+        # leave a single spelling of the tag so there is nothing to merge
+        self.tag2.delete()
+        self.tag3.delete()
+        self.assertEqual(Tag.objects.count(), 1)
+        self.assertEqual(TaggedItem.objects.count(), 1)
+
+        out = StringIO()
+        call_command("deduplicate_tags", stdout=out)
+
+        # nothing was merged, so nothing changed
+        self.assertEqual(Tag.objects.count(), 1)
+        self.assertEqual(Tag.objects.get(), self.tag1)
+        self.assertEqual(TaggedItem.objects.count(), 1)
+
+        self.assertIn("Tag deduplication complete.", out.getvalue())
+
+    @override_settings(TAGGIT_CASE_INSENSITIVE=False)
+    def test_taggit_case_insensitive_not_enabled(self):
+        out = StringIO()
+        call_command("deduplicate_tags", stdout=out)
+
+        self.assertIn("TAGGIT_CASE_INSENSITIVE is not enabled.", out.getvalue())
+
+        self.assertEqual(Tag.objects.count(), 3)
+        self.assertEqual(TaggedItem.objects.count(), 3)