Skip to content

Commit

Permalink
Fix code issues
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Oct 25, 2020
1 parent 091330f commit bad72ee
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 34 deletions.
8 changes: 6 additions & 2 deletions tests/codegen/models/test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def test_dependencies(self):
],
choices=[
AttrChoiceFactory.create(
name="x", types=[AttrTypeFactory.create(qname="choiceAttr")]
name="x",
types=[
AttrTypeFactory.create(qname="choiceAttr"),
AttrTypeFactory.xs_string(),
],
),
AttrChoiceFactory.create(
name="x",
Expand Down Expand Up @@ -81,7 +85,7 @@ def test_dependencies(self):
"{http://www.w3.org/2001/XMLSchema}foobar",
"{xsdata}foo",
]
self.assertEqual(expected, list(obj.dependencies()))
self.assertCountEqual(expected, list(obj.dependencies()))

def test_property_has_suffix_attr(self):
obj = ClassFactory.create()
Expand Down
6 changes: 4 additions & 2 deletions tests/formats/dataclass/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
from xsdata.formats.dataclass.models.elements import XmlAttributes
from xsdata.formats.dataclass.models.elements import XmlElement
from xsdata.formats.dataclass.models.elements import XmlElements
from xsdata.formats.dataclass.models.elements import XmlMeta
from xsdata.formats.dataclass.models.elements import XmlText
from xsdata.formats.dataclass.models.elements import XmlVar
from xsdata.formats.dataclass.models.elements import XmlWildcard
from xsdata.models.enums import FormType


@dataclass
Expand Down Expand Up @@ -64,6 +62,10 @@ def test_find_choice(self):
var = XmlVar(name="foo", qname="foo")
self.assertIsNone(var.find_choice("foo"))

def test_find_choice_typed(self):
var = XmlVar(name="foo", qname="foo")
self.assertIsNone(var.find_choice_typed(int))


class XmlElementTests(TestCase):
def test_property_is_element(self):
Expand Down
40 changes: 15 additions & 25 deletions xsdata/codegen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,14 @@ def merge(self, source: "Restrictions"):
if source.sequential and (is_list or not self.is_list):
self.sequential = source.sequential

if source.choice:
self.choice = source.choice
self.choice = source.choice or self.choice
self.tokens = source.tokens or self.tokens

if not self.tokens and source.tokens:
self.tokens = True

# Update min occurs if current value is None and the new value is more than one.
# Update min occurs if current value is None or the new value is more than one.
if self.min_occurs is None or (min_occurs is not None and min_occurs != 1):
self.min_occurs = min_occurs

# Update max occurs if current value is None and the new value is more than one.
# Update max occurs if current value is None or the new value is more than one.
if self.max_occurs is None or (max_occurs is not None and max_occurs != 1):
self.max_occurs = max_occurs

Expand Down Expand Up @@ -190,8 +187,8 @@ class AttrType:
"""

qname: str
index: int = field(default_factory=int)
alias: Optional[str] = field(default=None)
index: int = field(default_factory=int, compare=False)
alias: Optional[str] = field(default=None, compare=False)
native: bool = field(default=False)
forward: bool = field(default=False)
circular: bool = field(default=False)
Expand Down Expand Up @@ -522,28 +519,21 @@ def dependencies(self) -> Iterator[str]:
Collect:
* base classes
* attribute types
* attribute choice types
* recursively go through the inner classes
* Ignore inner class references
* Ignore native types.
"""

seen = set()
types = {ext.type for ext in self.extensions}

for attr in self.attrs:
for attr_type in attr.types:
if attr_type.is_dependency and attr_type.name not in seen:
yield attr_type.qname
seen.add(attr_type.name)

for attr_choice in attr.choices:
for attr_type in attr_choice.types:
if attr_type.is_dependency and attr_type.name not in seen:
yield attr_type.qname
seen.add(attr_type.name)

for ext in self.extensions:
if ext.type.is_dependency and ext.type.name not in seen:
yield ext.type.qname
seen.add(ext.type.name)
types.update(attr.types)
types.update(tp for choice in attr.choices for tp in choice.types)

for tp in types:
if tp.is_dependency:
yield tp.qname

for inner in self.inner:
yield from inner.dependencies()
Expand Down
18 changes: 13 additions & 5 deletions xsdata/formats/dataclass/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from xsdata.formats.dataclass.models.elements import XmlMeta
from xsdata.formats.dataclass.models.elements import XmlVar
from xsdata.models.enums import NamespaceType
from xsdata.utils.collections import first
from xsdata.utils.constants import EMPTY_SEQUENCE
from xsdata.utils.namespaces import build_qname

Expand Down Expand Up @@ -150,8 +149,8 @@ def get_type_hints(self, clazz: Type, parent_ns: Optional[str]) -> Iterator[XmlV
xml_clazz = XmlType.to_xml_class(xml_type)
namespace = var.metadata.get("namespace")
namespaces = self.resolve_namespaces(xml_type, namespace, parent_ns)
first_namespace = first(x for x in namespaces if x and x[0] != "#")
qname = build_qname(first_namespace, local_name)
default_namespace = self.default_namespace(namespaces)
qname = build_qname(default_namespace, local_name)

choices = list(
self.build_choices(
Expand Down Expand Up @@ -190,12 +189,12 @@ def build_choices(
xml_type = choice.get("tag", XmlType.ELEMENT)
namespace = choice.get("namespace")
namespaces = self.resolve_namespaces(xml_type, namespace, parent_namespace)
first_namespace = first(x for x in namespaces if x and x[0] != "#")
default_namespace = self.default_namespace(namespaces)

types = self.real_types(_eval_type(choice["type"], globalns, None))
is_class = any(is_dataclass(clazz) for clazz in types)
xml_clazz = XmlType.to_xml_class(xml_type)
qname = build_qname(first_namespace, choice.get("name", "any"))
qname = build_qname(default_namespace, choice.get("name", "any"))

yield xml_clazz(
name=parent_name,
Expand Down Expand Up @@ -242,6 +241,15 @@ def resolve_namespaces(
result.add(ns)
return list(result)

@classmethod
def default_namespace(cls, namespaces: List[str]) -> Optional[str]:
"""Return the first valid namespace uri or None."""
for namespace in namespaces:
if namespace and not namespace.startswith("#"):
return namespace

return None

@classmethod
def default_value(cls, var: Field) -> Any:
"""Return the default value/factory for the given field."""
Expand Down

0 comments on commit bad72ee

Please sign in to comment.