Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax conditions on extensions flattening #754

Merged
merged 7 commits into from
Feb 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions tests/codegen/handlers/test_create_compound_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from xsdata.codegen.models import Restrictions
from xsdata.models.config import GeneratorConfig
from xsdata.models.enums import DataType
from xsdata.models.enums import Tag
from xsdata.utils.testing import AttrFactory
from xsdata.utils.testing import AttrTypeFactory
from xsdata.utils.testing import ClassFactory
Expand Down Expand Up @@ -123,12 +124,6 @@ def test_choose_name(self):
actual = self.processor.choose_name(target, ["a", "b", "c", "d"])
self.assertEqual("choice_1", actual)

base = ClassFactory.create()
base.attrs.append(AttrFactory.create(name="Choice!"))
target.extensions.append(ExtensionFactory.reference(base.qname))
self.container.extend((target, base))

target.attrs.clear()
actual = self.processor.choose_name(target, ["a", "b", "c", "d"])
self.assertEqual("choice_1", actual)

Expand All @@ -140,6 +135,39 @@ def test_choose_name(self):
actual = self.processor.choose_name(target, ["a", "b", "c"])
self.assertEqual("ThisOrThat", actual)

def test_build_reserved_names(self):
base = ClassFactory.create(
attrs=[
AttrFactory.create("standalone"),
AttrFactory.create(
name="first",
tag=Tag.CHOICE,
choices=[
AttrFactory.create(name="a"),
AttrFactory.create(name="b"),
AttrFactory.create(name="c"),
],
),
AttrFactory.create(
name="second",
tag=Tag.CHOICE,
choices=[
AttrFactory.create(name="b"),
AttrFactory.create(name="c"),
],
),
]
)

target = ClassFactory.create()
target.extensions.append(ExtensionFactory.reference(qname=base.qname))
self.processor.container.extend([base, target])

actual = self.processor.build_reserved_names(target, names=["b", "c"])
expected = {"standalone", "first"}

self.assertEqual(expected, actual)

def test_build_attr_choice(self):
attr = AttrFactory.create(
name="a", namespace="xsdata", default="123", help="help", fixed=True
Expand Down
16 changes: 0 additions & 16 deletions tests/codegen/handlers/test_flatten_class_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,22 +401,6 @@ def test_should_flatten_extension(self):
target.attrs = [target.attrs[1], target.attrs[0], target.attrs[2]]
self.assertTrue(self.processor.should_flatten_extension(source, target))

# Types violation
target = source.clone()
target.attrs[1].types = [
AttrTypeFactory.native(DataType.INT),
AttrTypeFactory.native(DataType.FLOAT),
]

source.attrs[1].types = [
AttrTypeFactory.native(DataType.INT),
AttrTypeFactory.native(DataType.FLOAT),
AttrTypeFactory.native(DataType.DECIMAL),
]
self.assertFalse(self.processor.should_flatten_extension(source, target))
target.attrs[1].types.append(AttrTypeFactory.native(DataType.QNAME))
self.assertTrue(self.processor.should_flatten_extension(source, target))

def test_replace_attributes_type(self):
extension = ExtensionFactory.create()
target = ClassFactory.elements(2)
Expand Down
22 changes: 22 additions & 0 deletions tests/codegen/handlers/test_validate_attributes_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from xsdata.codegen.handlers import ValidateAttributesOverrides
from xsdata.codegen.models import Status
from xsdata.models.config import GeneratorConfig
from xsdata.models.enums import DataType
from xsdata.models.enums import Tag
from xsdata.utils.testing import AttrFactory
from xsdata.utils.testing import ClassFactory
Expand Down Expand Up @@ -50,6 +51,18 @@ def test_process(self, mock_validate_override, mock_resolve_conflict):
class_a.attrs[1], class_c.attrs[1]
)

def test_overrides(self):
a = AttrFactory.create(tag=Tag.SIMPLE_TYPE)
b = a.clone()

self.assertTrue(self.processor.overrides(a, b))

b.tag = Tag.EXTENSION
self.assertTrue(self.processor.overrides(a, b))

b.namespace = "foo"
self.assertFalse(self.processor.overrides(a, b))

def test_validate_override(self):
attr_a = AttrFactory.create()
attr_b = attr_a.clone()
Expand Down Expand Up @@ -91,6 +104,7 @@ def test_validate_override(self):
self.processor.validate_override(target, attr_a, attr_b)
self.assertEqual(0, len(target.attrs))

# Source is list, parent is not
target.attrs.append(attr_a)
attr_a.restrictions.min_occurs = None
attr_a.restrictions.max_occurs = 10
Expand All @@ -99,6 +113,14 @@ def test_validate_override(self):
self.processor.validate_override(target, attr_a, attr_b)
self.assertEqual(sys.maxsize, attr_b.restrictions.max_occurs)

# Parent is any type, source isn't, skip
attr_a = AttrFactory.native(DataType.STRING)
attr_b = AttrFactory.native(DataType.ANY_SIMPLE_TYPE)
target = ClassFactory.create(attrs=[attr_a])

self.processor.validate_override(target, attr_a.clone(), attr_b)
self.assertEqual(attr_a, target.attrs[0])

def test_resolve_conflicts(self):
a = AttrFactory.create(name="foo", tag=Tag.ATTRIBUTE)
b = a.clone()
Expand Down
12 changes: 12 additions & 0 deletions tests/codegen/models/test_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,18 @@ def test_property_is_nameless(self):
self.assertFalse(AttrFactory.create(tag=Tag.ATTRIBUTE).is_nameless)
self.assertTrue(AttrFactory.create(tag=Tag.ANY).is_nameless)

def test_property_is_any_type(self):
attr = AttrFactory.create(
types=[
AttrTypeFactory.create(qname="foo"),
AttrTypeFactory.native(DataType.FLOAT),
]
)
self.assertFalse(attr.is_any_type)

attr.types.append(AttrTypeFactory.native(DataType.ANY_SIMPLE_TYPE))
self.assertTrue(attr.is_any_type)

def test_property_native_types(self):
attr = AttrFactory.create(
types=[
Expand Down
11 changes: 11 additions & 0 deletions tests/models/xsd/test_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,14 @@ def test_property_bases(self):
obj.type = "foo"
obj.simple_type = None
self.assertEqual(["foo"], list(obj.bases))

def test_property_default_type(self):
obj = Attribute()
self.assertEqual("anySimpleType", obj.default_type)

obj = Attribute()
obj.ns_map["foo"] = Namespace.XS.uri
self.assertEqual("foo:anySimpleType", obj.default_type)

obj.fixed = "aa"
self.assertEqual("foo:string", obj.default_type)
20 changes: 16 additions & 4 deletions xsdata/codegen/handlers/create_compound_fields.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from collections import Counter
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Set

from xsdata.codegen.mixins import ContainerInterface
from xsdata.codegen.mixins import RelativeHandlerInterface
from xsdata.codegen.models import Attr
from xsdata.codegen.models import AttrType
from xsdata.codegen.models import Class
from xsdata.codegen.models import get_restriction_choice
from xsdata.codegen.models import get_slug
from xsdata.codegen.models import Restrictions
from xsdata.codegen.utils import ClassUtils
from xsdata.formats.dataclass.models.elements import XmlType
from xsdata.models.enums import DataType
from xsdata.models.enums import Tag
from xsdata.utils.collections import group_by
Expand Down Expand Up @@ -77,9 +79,6 @@ def group_fields(self, target: Class, attrs: List[Attr]):
)

def choose_name(self, target: Class, names: List[str]) -> str:
reserved = set(map(get_slug, self.base_attrs(target)))
reserved.update(map(get_slug, target.attrs))

if (
self.config.force_default_name
or len(names) > 3
Expand All @@ -89,8 +88,21 @@ def choose_name(self, target: Class, names: List[str]) -> str:
else:
name = "_Or_".join(names)

reserved = self.build_reserved_names(target, names)
return ClassUtils.unique_name(name, reserved)

def build_reserved_names(self, target: Class, names: List[str]) -> Set[str]:
names_counter = Counter(names)
all_attrs = self.base_attrs(target)
all_attrs.extend(target.attrs)

return {
attr.slug
for attr in all_attrs
if attr.xml_type != XmlType.ELEMENTS
or Counter([x.local_name for x in attr.choices]) != names_counter
}

@classmethod
def build_attr_choice(cls, attr: Attr) -> Attr:
"""
Expand Down
12 changes: 0 additions & 12 deletions xsdata/codegen/handlers/flatten_class_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,24 +213,12 @@ def should_flatten_extension(cls, source: Class, target: Class) -> bool:
source.is_simple_type
or target.has_suffix_attr
or (source.has_suffix_attr and target.attrs)
or not cls.validate_type_overrides(source, target)
or not cls.validate_sequence_order(source, target)
):
return True

return False

@classmethod
def validate_type_overrides(cls, source: Class, target: Class) -> bool:
"""Validate every override is using a subset of the parent attr
types."""
for attr in target.attrs:
src_attr = ClassUtils.find_attr(source, attr.name)
if src_attr and any(tp not in src_attr.types for tp in attr.types):
return False

return True

@classmethod
def validate_sequence_order(cls, source: Class, target: Class) -> bool:
"""
Expand Down
10 changes: 7 additions & 3 deletions xsdata/codegen/handlers/sanitize_attributes_default_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,14 @@ def should_reset_required(cls, attr: Attr) -> bool:
"""
Return whether the min occurrences for the attr needs to be reset.

Cases:
1. xs:any(Simple)Type, with no default value that's not a list already!
@Todo figure out if wildcards are supposed to be optional!
"""
return attr.default is None and object in attr.native_types and not attr.is_list
return (
not attr.is_attribute
and attr.default is None
and object in attr.native_types
and not attr.is_list
)

@classmethod
def should_reset_default(cls, attr: Attr) -> bool:
Expand Down
9 changes: 8 additions & 1 deletion xsdata/codegen/handlers/validate_attributes_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,24 @@ def process(self, target: Class):

if base_attrs:
base_attr = base_attrs[0]
if attr.tag == base_attr.tag:
if self.overrides(attr, base_attr):
self.validate_override(target, attr, base_attr)
else:
self.resolve_conflict(attr, base_attr)

@classmethod
def overrides(cls, a: Attr, b: Attr) -> bool:
return a.xml_type == b.xml_type and a.namespace == b.namespace

def base_attrs_map(self, target: Class) -> Dict[str, List[Attr]]:
base_attrs = self.base_attrs(target)
return collections.group_by(base_attrs, key=get_slug)

@classmethod
def validate_override(cls, target: Class, attr: Attr, source_attr: Attr):
if source_attr.is_any_type and not attr.is_any_type:
return

if attr.is_list and not source_attr.is_list:
# Hack much??? idk but Optional[str] can't override List[str]
source_attr.restrictions.max_occurs = sys.maxsize
Expand Down
18 changes: 11 additions & 7 deletions xsdata/codegen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,16 +336,14 @@ def is_wildcard(self) -> bool:
xs:any."""
return self.tag in (Tag.ANY_ATTRIBUTE, Tag.ANY)

@property
def is_any_type(self) -> bool:
return any(tp is object for tp in self.get_native_types())

@property
def native_types(self) -> List[Type]:
"""Return a list of all builtin data types."""
result = set()
for tp in self.types:
datatype = tp.datatype
if datatype:
result.add(datatype.type)

return list(result)
return list(set(self.get_native_types()))

@property
def user_types(self) -> Iterator[AttrType]:
Expand All @@ -371,6 +369,12 @@ def clone(self) -> "Attr":
restrictions=self.restrictions.clone(),
)

def get_native_types(self) -> Iterator[Type]:
for tp in self.types:
datatype = tp.datatype
if datatype:
yield datatype.type


@dataclass(unsafe_hash=True)
class Extension:
Expand Down
5 changes: 5 additions & 0 deletions xsdata/models/xsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@ def attr_types(self) -> Iterator[str]:
elif self.ref:
yield self.ref

@property
def default_type(self) -> str:
datatype = DataType.STRING if self.fixed else DataType.ANY_SIMPLE_TYPE
return datatype.prefixed(self.xs_prefix)

def get_restrictions(self) -> Dict[str, Anything]:
if self.use == UseType.REQUIRED:
restrictions = {"min_occurs": 1, "max_occurs": 1}
Expand Down