Skip to content

Commit

Permalink
Refine and test the find_assigned_other_versions #12
Browse files Browse the repository at this point in the history
Signed-off-by: tdruez <tdruez@nexb.com>
  • Loading branch information
tdruez committed May 24, 2024
1 parent 81561dc commit 33fe828
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 33 deletions.
6 changes: 5 additions & 1 deletion component_catalog/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1915,12 +1915,16 @@ def get_export_cyclonedx_url(self):
return self.get_url("export_cyclonedx")

@classmethod
def get_identifier_fields(cls):
def get_identifier_fields(cls, *args, purl_fields_only=False, **kwargs):
"""
Explicit list of identifier fields as we do not enforce a unique together
on this model.
This is used in the Importer, to catch duplicate entries.
The purl_fields_only option can be use to limit the results.
"""
if purl_fields_only:
return PACKAGE_URL_FIELDS

return ["filename", "download_url", *PACKAGE_URL_FIELDS]

@property
Expand Down
3 changes: 3 additions & 0 deletions component_catalog/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from component_catalog.importers import ComponentImporter
from component_catalog.importers import PackageImporter
from component_catalog.models import PACKAGE_URL_FIELDS
from component_catalog.models import Component
from component_catalog.models import ComponentAssignedLicense
from component_catalog.models import ComponentAssignedPackage
Expand Down Expand Up @@ -369,6 +370,8 @@ def test_component_catalog_models_get_identifier_fields(self):
for model_class, expected in inputs:
self.assertEqual(expected, model_class.get_identifier_fields())

self.assertEqual(PACKAGE_URL_FIELDS, Package.get_identifier_fields(purl_fields_only=True))

def test_component_model_get_absolute_url(self):
c = Component(name="c1", version="1.0", dataspace=self.dataspace)
self.assertEqual("/components/nexB/c1/1.0/", c.get_absolute_url())
Expand Down
2 changes: 1 addition & 1 deletion dje/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ def _get_local_foreign_fields(self):
local_foreign_fields = property(_get_local_foreign_fields)

@classmethod
def get_identifier_fields(cls):
def get_identifier_fields(cls, *args, **kwargs):
"""
Return a list of the fields, based on the Meta unique_together, to be
used to match a unique instance within a Dataspace.
Expand Down
2 changes: 1 addition & 1 deletion policy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def str_with_content_type(self):
return f"{self.label} ({self.content_type.model})"

@classmethod
def get_identifier_fields(cls):
def get_identifier_fields(cls, *args, **kwargs):
"""Hack required by the Component import."""
return ["label"]

Expand Down
27 changes: 11 additions & 16 deletions product_portfolio/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
#

import uuid
from contextlib import suppress
from pathlib import Path

from django.conf import settings
from django.core.exceptions import MultipleObjectsReturned
from django.core.exceptions import ObjectDoesNotExist
from django.core.exceptions import ValidationError
from django.db import models
from django.utils.functional import cached_property
Expand Down Expand Up @@ -348,23 +345,18 @@ def get_relationship_model(self, obj):

return relationship_model

def find_assigned_other_version(self, obj):
def find_assigned_other_versions(self, obj):
"""
Look for the same object with a different version already assigned to the product.
Return:
------
relation: The relationship for the object with a different version, if found.
None: If no such relation exists or if multiple versions are found.
Look for the same objects with a different version already assigned to the product.
Return the relation queryset for the objects with a different version.
"""
object_model_name = obj._meta.model_name # "component" or "package"
relationship_model = self.get_relationship_model(obj)

# Craft the lookups excluding the version field
no_version_object_lookups = {
f"{object_model_name}__{field_name}": getattr(obj, field_name)
for field_name in obj.get_identifier_fields()
for field_name in obj.get_identifier_fields(purl_fields_only=True)
if field_name != "version"
}

Expand All @@ -373,9 +365,11 @@ def find_assigned_other_version(self, obj):
"dataspace": obj.dataspace,
**no_version_object_lookups,
}
excludes = {
f"{object_model_name}__id": obj.id,
}

with suppress(ObjectDoesNotExist, MultipleObjectsReturned):
return relationship_model.objects.get(**filters)
return relationship_model.objects.exclude(**excludes).filter(**filters)

def assign_object(self, obj, user, replace_existing_version=False):
"""
Expand All @@ -399,8 +393,9 @@ def assign_object(self, obj, user, replace_existing_version=False):
# 2. Find an existing relation for another version of the same object
# when replace_existing_version is provided.
if replace_existing_version:
# other_version = self.find_assigned_other_version(obj)
pass
other_versions = self.find_assigned_other_versions(obj)
if len(other_versions) == 1:
print("Remplace the obj on the relationship with:", other_versions[0])

# 3. Otherwise, create a new relation
defaults = {
Expand Down
68 changes: 54 additions & 14 deletions product_portfolio/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,28 +285,68 @@ def test_product_model_assign_objects(self):
)
self.assertEqual(self.super_user, self.product1.last_modified_by)

def test_product_model_find_assigned_other_version_component(self):
def test_product_model_find_assigned_other_versions_component(self):
component1 = Component.objects.create(name="c", version="1.0", dataspace=self.dataspace)
component2 = Component.objects.create(name="c", version="2.0", dataspace=self.dataspace)
component3 = Component.objects.create(name="c", version="3.0", dataspace=self.dataspace)

# No other version assigned
self.assertIsNone(self.product1.find_assigned_other_version(component1))
self.assertIsNone(self.product1.find_assigned_other_version(component2))
self.assertIsNone(self.product1.find_assigned_other_version(component3))
self.assertQuerySetEqual([], self.product1.find_assigned_other_versions(component1))
self.assertQuerySetEqual([], self.product1.find_assigned_other_versions(component2))
self.assertQuerySetEqual([], self.product1.find_assigned_other_versions(component3))

# 1 other version assigned
p1_c1 = self.product1.assign_object(component1, self.super_user)
# TODO: Do not include self?
# self.assertIsNone(self.product1.find_assigned_other_version(component1))
self.assertEqual(p1_c1, self.product1.find_assigned_other_version(component2))
self.assertEqual(p1_c1, self.product1.find_assigned_other_version(component3))

# Multiple other version assigned
self.product1.assign_object(component2, self.super_user)
self.assertIsNone(self.product1.find_assigned_other_version(component1))
self.assertIsNone(self.product1.find_assigned_other_version(component2))
self.assertIsNone(self.product1.find_assigned_other_version(component3))
self.assertQuerySetEqual([], self.product1.find_assigned_other_versions(component1))
self.assertQuerySetEqual([p1_c1], self.product1.find_assigned_other_versions(component2))
self.assertQuerySetEqual([p1_c1], self.product1.find_assigned_other_versions(component3))

# 2 other versions assigned
p1_c2 = self.product1.assign_object(component2, self.super_user)
self.assertQuerySetEqual([p1_c2], self.product1.find_assigned_other_versions(component1))
self.assertQuerySetEqual([p1_c1], self.product1.find_assigned_other_versions(component2))
self.assertQuerySetEqual(
[p1_c1, p1_c2], self.product1.find_assigned_other_versions(component3)
)

def test_product_model_find_assigned_other_versions_package(self):
package_data = {
"filename": "package.zip",
"type": "deb",
"namespace": "debian",
"name": "curl",
"dataspace": self.dataspace,
}
package1 = Package.objects.create(**package_data, version="1.0")
package2 = Package.objects.create(**package_data, version="2.0")
package3 = Package.objects.create(**package_data, version="3.0")

# No other version assigned
self.assertQuerySetEqual([], self.product1.find_assigned_other_versions(package1))
self.assertQuerySetEqual([], self.product1.find_assigned_other_versions(package2))
self.assertQuerySetEqual([], self.product1.find_assigned_other_versions(package3))

# 1 other version assigned
p1_p1 = self.product1.assign_object(package1, self.super_user)
self.assertQuerySetEqual([], self.product1.find_assigned_other_versions(package1))
self.assertQuerySetEqual([p1_p1], self.product1.find_assigned_other_versions(package2))
self.assertQuerySetEqual([p1_p1], self.product1.find_assigned_other_versions(package3))

# 2 other versions assigned
p1_p2 = self.product1.assign_object(package2, self.super_user)
self.assertQuerySetEqual([p1_p2], self.product1.find_assigned_other_versions(package1))
self.assertQuerySetEqual([p1_p1], self.product1.find_assigned_other_versions(package2))
self.assertQuerySetEqual(
[p1_p1, p1_p2], self.product1.find_assigned_other_versions(package3)
)

# Only PURL fields are used as lookups as the filename and download_url
# fields change between version.
package_data["filename"] = "different_filename"
package4 = Package.objects.create(**package_data, version="4.0")
self.assertQuerySetEqual(
[p1_p1, p1_p2], self.product1.find_assigned_other_versions(package4)
)

def test_product_model_field_changes_mixin(self):
self.assertFalse(Product().has_changed("name"))
Expand Down

0 comments on commit 33fe828

Please sign in to comment.