Skip to content

Commit

Permalink
Fix #194 AttributeGroup handler search for group sources only
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Jun 8, 2020
1 parent 3c90214 commit 5a8f7fd
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 94 deletions.
56 changes: 35 additions & 21 deletions tests/codegen/handlers/test_attribute_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
from xsdata.codegen.container import ClassContainer
from xsdata.codegen.handlers import AttributeGroupHandler
from xsdata.codegen.models import Attr
from xsdata.codegen.models import Status
from xsdata.codegen.utils import ClassUtils
from xsdata.exceptions import AnalyzerValueError
from xsdata.models.xsd import AttributeGroup
from xsdata.models.xsd import ComplexType
from xsdata.models.xsd import Group


class AttributeGroupHandlerTests(FactoryTestCase):
Expand Down Expand Up @@ -40,37 +44,47 @@ def test_process(self, mock_process_attribute, mock_is_group):
[mock.call(target, target.attrs[1]), mock.call(target, target.attrs[0]),]
)

@mock.patch.object(ClassUtils, "clone_attribute")
@mock.patch.object(ClassContainer, "find")
def test_process_attribute(self, mock_container_find, mock_clone_attribute):
source = ClassFactory.elements(2)
group_attr = AttrFactory.attribute_group(name="foo:bar")
@mock.patch.object(ClassUtils, "copy_group_attributes")
def test_process_attribute_with_group(self, mock_copy_group_attributes):
complex_bar = ClassFactory.create(type=ComplexType, name="bar")
group_bar = ClassFactory.create(type=Group, name="bar")
group_attr = AttrFactory.attribute_group(name="bar")
target = ClassFactory.create()
target.attrs.append(group_attr)

mock_container_find.return_value = source
mock_clone_attribute.side_effect = lambda x, y, z: x.clone()
self.processor.container.add(complex_bar)
self.processor.container.add(group_bar)
self.processor.container.add(target)

self.processor.process_attribute(target, group_attr)
mock_copy_group_attributes.assert_called_once_with(
group_bar, target, group_attr
)

@mock.patch.object(ClassUtils, "copy_group_attributes")
def test_process_attribute_with_attribute_group(self, mock_copy_group_attributes):
complex_bar = ClassFactory.create(type=ComplexType, name="bar")
group_bar = ClassFactory.create(type=AttributeGroup, name="bar")
group_attr = AttrFactory.attribute_group(name="bar")
target = ClassFactory.create()
target.attrs.append(group_attr)

self.assertEqual(2, len(target.attrs))
self.assertIsNot(source.attrs[0], target.attrs[0])
self.assertIsNot(source.attrs[1], target.attrs[1])
self.assertNotIn(group_attr, target.attrs)
self.processor.container.add(complex_bar)
self.processor.container.add(group_bar)
self.processor.container.add(target)

mock_clone_attribute.assert_has_calls(
[
mock.call(source.attrs[0], group_attr.restrictions, "foo"),
mock.call(source.attrs[1], group_attr.restrictions, "foo"),
]
self.processor.process_attribute(target, group_attr)
mock_copy_group_attributes.assert_called_once_with(
group_bar, target, group_attr
)

@mock.patch.object(ClassContainer, "find")
def test_process_attribute_with_circular_reference(self, mock_container_find):
group_attr = AttrFactory.attribute_group(name="foo:bar")
target = ClassFactory.create()
def test_process_attribute_with_circular_reference(self):
group_attr = AttrFactory.attribute_group(name="bar")
target = ClassFactory.create(name="bar", type=Group)
target.attrs.append(group_attr)
mock_container_find.return_value = target

target.status = Status.PROCESSING
self.processor.container.add(target)

self.processor.process_attribute(target, group_attr)
self.assertFalse(group_attr in target.attrs)
Expand Down
101 changes: 61 additions & 40 deletions tests/codegen/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from xsdata.codegen.container import ClassContainer
from xsdata.codegen.utils import ClassUtils
from xsdata.codegen.validator import ClassValidator
from xsdata.models.enums import Tag
from xsdata.models.xsd import ComplexType
from xsdata.models.xsd import Element
from xsdata.models.xsd import SimpleType
Expand All @@ -21,12 +22,12 @@ def setUp(self):
self.validator = ClassValidator(container=container)

@mock.patch.object(ClassValidator, "update_abstract_classes")
@mock.patch.object(ClassValidator, "merge_redefined_classes")
@mock.patch.object(ClassValidator, "handle_duplicate_types")
@mock.patch.object(ClassValidator, "remove_invalid_classes")
def test_handle_duplicate_classes(
def test_process(
self,
mock_remove_invalid_classes,
mock_merge_redefined_classes,
mock_handle_duplicate_types,
mock_update_abstract_classes,
):
first = ClassFactory.create()
Expand All @@ -37,7 +38,7 @@ def test_handle_duplicate_classes(
self.validator.process()

mock_remove_invalid_classes.assert_called_once_with([first, second])
mock_merge_redefined_classes.assert_called_once_with([first, second])
mock_handle_duplicate_types.assert_called_once_with([first, second])
mock_update_abstract_classes.assert_called_once_with([first, second])

def test_remove_invalid_classes(self):
Expand All @@ -58,6 +59,45 @@ def test_remove_invalid_classes(self):
self.validator.remove_invalid_classes(classes)
self.assertEqual([second, third], classes)

@mock.patch.object(ClassValidator, "select_winner")
def test_handle_duplicate_types(self, mock_select_winner):

one = ClassFactory.create()
two = one.clone()
three = one.clone()
four = ClassFactory.create()

mock_select_winner.return_value = 0

classes = [one, two, three, four]

self.validator.handle_duplicate_types(classes)
self.assertEqual([one, four], classes)
mock_select_winner.assert_called_once_with([one, two, three])

@mock.patch.object(ClassValidator, "merge_redefined_type")
@mock.patch.object(ClassValidator, "select_winner")
def test_handle_duplicate_types_with_redefined_type(
self, mock_select_winner, mock_merge_redefined_type
):

one = ClassFactory.create()
two = one.clone()
three = one.clone()
four = ClassFactory.create()

mock_select_winner.return_value = 0
one.container = Tag.REDEFINE

classes = [one, two, three, four]

self.validator.handle_duplicate_types(classes)
self.assertEqual([one, four], classes)
mock_select_winner.assert_called_once_with([one, two, three])
mock_merge_redefined_type.assert_has_calls(
[mock.call(two, one), mock.call(three, one),]
)

def test_update_abstract_classes(self):
one = ClassFactory.create(name="foo", abstract=True, type=Element)
two = ClassFactory.create(name="foo", type=Element)
Expand All @@ -71,52 +111,33 @@ def test_update_abstract_classes(self):
self.assertTrue(three.abstract) # Marked as abstract
self.assertFalse(four.abstract) # Is common

def test_merge_redefined_classes_selects_last_defined_class(self):
class_a = ClassFactory.create()
class_b = ClassFactory.create()
class_c = class_a.clone()
classes = [class_a, class_b, class_c]

self.validator.merge_redefined_classes(classes)
self.assertEqual(2, len(classes))
self.assertIn(class_b, classes)
self.assertIn(class_c, classes)

@mock.patch.object(ClassUtils, "copy_extensions")
@mock.patch.object(ClassUtils, "copy_attributes")
def test_merge_redefined_classes_with_circular_extension(
def test_merge_redefined_type_with_circular_extension(
self, mock_copy_attributes, mock_copy_extensions
):
class_a = ClassFactory.create()
class_b = ClassFactory.create()
class_c = class_a.clone()
source = ClassFactory.create()
target = source.clone()

ext_a = ExtensionFactory.create(type=AttrTypeFactory.create(name=class_a.name))
ext_a = ExtensionFactory.create(type=AttrTypeFactory.create(name=source.name))
ext_str = ExtensionFactory.create(type=AttrTypeFactory.create(name="foo"))
class_c.extensions.append(ext_str)
class_c.extensions.append(ext_a)
classes = [class_a, class_b, class_c]
target.extensions.append(ext_str)
target.extensions.append(ext_a)

self.validator.merge_redefined_classes(classes)
self.assertEqual(2, len(classes))
self.validator.merge_redefined_type(source, target)

mock_copy_attributes.assert_called_once_with(class_a, class_c, ext_a)
mock_copy_extensions.assert_called_once_with(class_a, class_c, ext_a)
mock_copy_attributes.assert_called_once_with(source, target, ext_a)
mock_copy_extensions.assert_called_once_with(source, target, ext_a)

@mock.patch.object(ClassUtils, "copy_group_attributes")
def test_merge_redefined_classes_with_circular_group(
self, mock_copy_group_attributes
):
class_a = ClassFactory.create()
class_c = class_a.clone()
def test_merge_redefined_type_with_circular_group(self, mock_copy_group_attributes):
source = ClassFactory.create()
target = source.clone()
target.container = Tag.REDEFINE
first_attr = AttrFactory.create()
second_attr = AttrFactory.create(name=class_a.name)
class_c.attrs.extend((first_attr, second_attr))
second_attr = AttrFactory.create(name=source.name)
target.attrs.extend((first_attr, second_attr))

classes = [class_a, class_c]
self.validator.merge_redefined_classes(classes)
self.assertEqual(1, len(classes))
self.validator.merge_redefined_type(source, target)

mock_copy_group_attributes.assert_called_once_with(
class_a, class_c, second_attr
)
mock_copy_group_attributes.assert_called_once_with(source, target, second_attr)
20 changes: 6 additions & 14 deletions xsdata/codegen/handlers/attribute_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from xsdata.codegen.models import Class
from xsdata.codegen.utils import ClassUtils
from xsdata.exceptions import AnalyzerValueError
from xsdata.utils import text
from xsdata.models.xsd import AttributeGroup
from xsdata.models.xsd import Group


@dataclass
Expand Down Expand Up @@ -37,23 +38,14 @@ def process_attribute(self, target: Class, attr: Attr):
attribute.
"""
qname = target.source_qname(attr.name)
source = self.container.find(qname)
source = self.container.find(
qname, condition=lambda x: x.type in (AttributeGroup, Group)
)

if not source:
raise AnalyzerValueError(f"Group attribute not found: `{qname}`")

if source is target:
target.attrs.remove(attr)
else:
index = target.attrs.index(attr)
target.attrs.pop(index)
prefix = text.prefix(attr.name)

for source_attr in source.attrs:
clone = ClassUtils.clone_attribute(
source_attr, attr.restrictions, prefix
)
target.attrs.insert(index, clone)
index += 1

ClassUtils.copy_inner_classes(source, target)
ClassUtils.copy_group_attributes(source, target, attr)
59 changes: 40 additions & 19 deletions xsdata/codegen/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def process(self):
Steps:
1. Remove classes with missing extension type.
2. Merge redefined classes.
2. Handle duplicate types.
3. Fix implied abstract flags.
"""
for classes in self.container.values():
Expand All @@ -34,7 +34,7 @@ def process(self):
self.remove_invalid_classes(classes)

if len(classes) > 1:
self.merge_redefined_classes(classes)
self.handle_duplicate_types(classes)

if len(classes) > 1:
self.update_abstract_classes(classes)
Expand All @@ -56,38 +56,59 @@ def is_invalid(source: Class, ext: Extension) -> bool:
classes.remove(target)

@classmethod
def merge_redefined_classes(cls, classes: List[Class]):
"""Merge original and redefined classes."""
def handle_duplicate_types(cls, classes: List[Class]):
"""Handle classes with same namespace, name that are derived from the
same xs type."""

grouped = group_by(classes, lambda x: f"{x.type.__name__}{x.source_qname()}")
for items in grouped.values():
if len(items) == 1:
continue

index = next(
(
index
for index, item in enumerate(items)
if item.container in (Tag.OVERRIDE, Tag.REDEFINE)
),
-1,
)
index = cls.select_winner(list(items))
winner = items.pop(index)

for item in items:
classes.remove(item)

if winner.container == Tag.REDEFINE:
cls.merge_redefined_type(item, winner)

circular_extension = cls.find_circular_extension(winner)
circular_group = cls.find_circular_group(winner)
@classmethod
def merge_redefined_type(cls, source: Class, target: Class):
"""
Copy any attributes and extensions to redefined types from the original
definitions.
Redefined inheritance is optional search for self references in
extensions and attribute groups.
"""
circular_extension = cls.find_circular_extension(target)
circular_group = cls.find_circular_group(target)

if circular_extension:
ClassUtils.copy_attributes(source, target, circular_extension)
ClassUtils.copy_extensions(source, target, circular_extension)

if circular_extension:
ClassUtils.copy_attributes(item, winner, circular_extension)
ClassUtils.copy_extensions(item, winner, circular_extension)
if circular_group:
ClassUtils.copy_group_attributes(source, target, circular_group)

if circular_group:
ClassUtils.copy_group_attributes(item, winner, circular_group)
@classmethod
def select_winner(cls, candidates: List[Class]) -> int:
"""
Returns the index of the class that will survive the duplicate process.
Classes that were extracted from in xs:override/xs:redefined
containers have priority, otherwise pick the last in the list.
"""
return next(
(
index
for index, item in enumerate(candidates)
if item.container in (Tag.OVERRIDE, Tag.REDEFINE)
),
-1,
)

@classmethod
def find_circular_extension(cls, target: Class) -> Optional[Extension]:
Expand Down

0 comments on commit 5a8f7fd

Please sign in to comment.