diff --git a/tests/codegen/models/test_class.py b/tests/codegen/models/test_class.py index f1332aab8..fb89598b2 100644 --- a/tests/codegen/models/test_class.py +++ b/tests/codegen/models/test_class.py @@ -29,7 +29,11 @@ def test_dependencies(self): ], choices=[ AttrChoiceFactory.create( - name="x", types=[AttrTypeFactory.create(qname="choiceAttr")] + name="x", + types=[ + AttrTypeFactory.create(qname="choiceAttr"), + AttrTypeFactory.xs_string(), + ], ), AttrChoiceFactory.create( name="x", @@ -81,7 +85,7 @@ def test_dependencies(self): "{http://www.w3.org/2001/XMLSchema}foobar", "{xsdata}foo", ] - self.assertEqual(expected, list(obj.dependencies())) + self.assertCountEqual(expected, list(obj.dependencies())) def test_property_has_suffix_attr(self): obj = ClassFactory.create() diff --git a/tests/formats/dataclass/test_elements.py b/tests/formats/dataclass/test_elements.py index 7f14dfc0e..7a4cc9edd 100644 --- a/tests/formats/dataclass/test_elements.py +++ b/tests/formats/dataclass/test_elements.py @@ -9,11 +9,9 @@ from xsdata.formats.dataclass.models.elements import XmlAttributes from xsdata.formats.dataclass.models.elements import XmlElement from xsdata.formats.dataclass.models.elements import XmlElements -from xsdata.formats.dataclass.models.elements import XmlMeta from xsdata.formats.dataclass.models.elements import XmlText from xsdata.formats.dataclass.models.elements import XmlVar from xsdata.formats.dataclass.models.elements import XmlWildcard -from xsdata.models.enums import FormType @dataclass @@ -64,6 +62,10 @@ def test_find_choice(self): var = XmlVar(name="foo", qname="foo") self.assertIsNone(var.find_choice("foo")) + def test_find_choice_typed(self): + var = XmlVar(name="foo", qname="foo") + self.assertIsNone(var.find_choice_typed(int)) + class XmlElementTests(TestCase): def test_property_is_element(self): diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index 0b2459729..6627c47e9 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -127,17 +127,14 @@ def merge(self, source: "Restrictions"): if source.sequential and (is_list or not self.is_list): self.sequential = source.sequential - if source.choice: - self.choice = source.choice + self.choice = source.choice or self.choice + self.tokens = source.tokens or self.tokens - if not self.tokens and source.tokens: - self.tokens = True - - # Update min occurs if current value is None and the new value is more than one. + # Update min occurs if current value is None or the new value is more than one. if self.min_occurs is None or (min_occurs is not None and min_occurs != 1): self.min_occurs = min_occurs - # Update max occurs if current value is None and the new value is more than one. + # Update max occurs if current value is None or the new value is more than one. if self.max_occurs is None or (max_occurs is not None and max_occurs != 1): self.max_occurs = max_occurs @@ -190,8 +187,8 @@ class AttrType: """ qname: str - index: int = field(default_factory=int) - alias: Optional[str] = field(default=None) + index: int = field(default_factory=int, compare=False) + alias: Optional[str] = field(default=None, compare=False) native: bool = field(default=False) forward: bool = field(default=False) circular: bool = field(default=False) @@ -522,28 +519,21 @@ def dependencies(self) -> Iterator[str]: Collect: * base classes * attribute types + * attribute choice types * recursively go through the inner classes * Ignore inner class references * Ignore native types. """ - seen = set() + types = {ext.type for ext in self.extensions} + for attr in self.attrs: - for attr_type in attr.types: - if attr_type.is_dependency and attr_type.name not in seen: - yield attr_type.qname - seen.add(attr_type.name) - - for attr_choice in attr.choices: - for attr_type in attr_choice.types: - if attr_type.is_dependency and attr_type.name not in seen: - yield attr_type.qname - seen.add(attr_type.name) - - for ext in self.extensions: - if ext.type.is_dependency and ext.type.name not in seen: - yield ext.type.qname - seen.add(ext.type.name) + types.update(attr.types) + types.update(tp for choice in attr.choices for tp in choice.types) + + for tp in types: + if tp.is_dependency: + yield tp.qname for inner in self.inner: yield from inner.dependencies() diff --git a/xsdata/formats/dataclass/context.py b/xsdata/formats/dataclass/context.py index dbf469cad..6939a687a 100644 --- a/xsdata/formats/dataclass/context.py +++ b/xsdata/formats/dataclass/context.py @@ -22,7 +22,6 @@ from xsdata.formats.dataclass.models.elements import XmlMeta from xsdata.formats.dataclass.models.elements import XmlVar from xsdata.models.enums import NamespaceType -from xsdata.utils.collections import first from xsdata.utils.constants import EMPTY_SEQUENCE from xsdata.utils.namespaces import build_qname @@ -150,8 +149,8 @@ def get_type_hints(self, clazz: Type, parent_ns: Optional[str]) -> Iterator[XmlV xml_clazz = XmlType.to_xml_class(xml_type) namespace = var.metadata.get("namespace") namespaces = self.resolve_namespaces(xml_type, namespace, parent_ns) - first_namespace = first(x for x in namespaces if x and x[0] != "#") - qname = build_qname(first_namespace, local_name) + default_namespace = self.default_namespace(namespaces) + qname = build_qname(default_namespace, local_name) choices = list( self.build_choices( @@ -190,12 +189,12 @@ def build_choices( xml_type = choice.get("tag", XmlType.ELEMENT) namespace = choice.get("namespace") namespaces = self.resolve_namespaces(xml_type, namespace, parent_namespace) - first_namespace = first(x for x in namespaces if x and x[0] != "#") + default_namespace = self.default_namespace(namespaces) types = self.real_types(_eval_type(choice["type"], globalns, None)) is_class = any(is_dataclass(clazz) for clazz in types) xml_clazz = XmlType.to_xml_class(xml_type) - qname = build_qname(first_namespace, choice.get("name", "any")) + qname = build_qname(default_namespace, choice.get("name", "any")) yield xml_clazz( name=parent_name, @@ -242,6 +241,15 @@ def resolve_namespaces( result.add(ns) return list(result) + @classmethod + def default_namespace(cls, namespaces: List[str]) -> Optional[str]: + """Return the first valid namespace uri or None.""" + for namespace in namespaces: + if namespace and not namespace.startswith("#"): + return namespace + + return None + @classmethod def default_value(cls, var: Field) -> Any: """Return the default value/factory for the given field."""