diff --git a/tests/codegen/handlers/test_add_attribute_substitutions.py b/tests/codegen/handlers/test_add_attribute_substitutions.py index 3a12191ca..b38f80aa4 100644 --- a/tests/codegen/handlers/test_add_attribute_substitutions.py +++ b/tests/codegen/handlers/test_add_attribute_substitutions.py @@ -121,3 +121,19 @@ def test_create_substitution(self): ) self.assertEqual(expected, actual) + + def test_prepare_substituted(self): + attr = AttrFactory.create() + attr.restrictions.min_occurs = 1 + attr.restrictions.path.append(("s", 0, 1, 1)) + + self.processor.prepare_substituted(attr) + + self.assertEqual(0, attr.restrictions.min_occurs) + self.assertEqual(id(attr), attr.restrictions.choice) + self.assertEqual(2, len(attr.restrictions.path)) + self.assertEqual(("c", id(attr), 1, 1), attr.restrictions.path[-1]) + + attr.restrictions.choice = 1 + self.processor.prepare_substituted(attr) + self.assertEqual(("c", id(attr), 1, 1), attr.restrictions.path[-1]) diff --git a/tests/codegen/handlers/test_calculate_attribute_paths.py b/tests/codegen/handlers/test_calculate_attribute_paths.py index 828111bd7..97a6f3945 100644 --- a/tests/codegen/handlers/test_calculate_attribute_paths.py +++ b/tests/codegen/handlers/test_calculate_attribute_paths.py @@ -47,14 +47,14 @@ def test_process(self): restrictions=Restrictions( min_occurs=1, max_occurs=1, - path=[("s", 1, 1, 1), ("c", 4, 0, 1)], + path=[("s", 1, 1, 1), ("c", 4, 0, 1), ("c", 100, 1, 1)], ) ), AttrFactory.element( restrictions=Restrictions( min_occurs=1, max_occurs=1, - path=[("s", 1, 1, 1), ("c", 4, 0, 1)], + path=[("s", 1, 1, 1), ("c", 4, 0, 1), ("a", 101, 1, 1)], ) ), AttrFactory.element( diff --git a/tests/formats/dataclass/parsers/nodes/test_element.py b/tests/formats/dataclass/parsers/nodes/test_element.py index e8aa50f4b..37eaa936e 100644 --- a/tests/formats/dataclass/parsers/nodes/test_element.py +++ b/tests/formats/dataclass/parsers/nodes/test_element.py @@ -401,6 +401,35 @@ def test_build_node_with_dataclass_var(self, mock_ctx_fetch, mock_xsi_type): mock_xsi_type.assert_called_once_with(attrs, ns_map) mock_ctx_fetch.assert_called_once_with(var.clazz, namespace, xsi_type) + @mock.patch.object(ParserUtils, "xsi_type", return_value="foo") + @mock.patch.object(XmlContext, "fetch") + def test_build_node_with_dataclass_var_and_mismatch_xsi_type( + self, mock_ctx_fetch, mock_xsi_type + ): + var = XmlVarFactory.create( + xml_type=XmlType.ELEMENT, + name="a", + qname="a", + types=(TypeB,), + derived=False, + ) + xsi_type = "foo" + namespace = self.meta.namespace + mock_ctx_fetch.return_value = self.meta + mock_xsi_type.return_value = xsi_type + + attrs = {"a": "b"} + ns_map = {"ns0": "xsdata"} + actual = self.node.build_node(var, attrs, ns_map, 10) + + self.assertIsInstance(actual, ElementNode) + self.assertEqual(10, actual.position) + self.assertEqual(DerivedElement, actual.derived_factory) + self.assertIs(mock_ctx_fetch.return_value, actual.meta) + + mock_xsi_type.assert_called_once_with(attrs, ns_map) + mock_ctx_fetch.assert_called_once_with(var.clazz, namespace, xsi_type) + @mock.patch.object(XmlContext, "fetch") def test_build_node_with_dataclass_var_validates_nillable(self, mock_ctx_fetch): var = XmlVarFactory.create(xml_type=XmlType.ELEMENT, qname="a", types=(TypeC,)) diff --git a/xsdata/codegen/handlers/add_attribute_substitutions.py b/xsdata/codegen/handlers/add_attribute_substitutions.py index 9e3b49891..06a140d25 100644 --- a/xsdata/codegen/handlers/add_attribute_substitutions.py +++ b/xsdata/codegen/handlers/add_attribute_substitutions.py @@ -55,11 +55,7 @@ def process_attribute(self, target: Class, attr: Attr): attr_type.substituted = True for substitution in self.substitutions.get(attr_type.qname, []): - attr.restrictions.min_occurs = 0 - - if not attr.restrictions.choice: - attr.restrictions.choice = id(attr) - attr.restrictions.path.append(("c", id(attr), 1, 1)) + self.prepare_substituted(attr) clone = ClassUtils.clone_attribute(substitution, attr.restrictions) clone.restrictions.min_occurs = 0 @@ -81,6 +77,14 @@ def create_substitutions(self): attr = self.create_substitution(obj) self.substitutions[qname].append(attr) + @classmethod + def prepare_substituted(cls, attr: Attr): + attr.restrictions.min_occurs = 0 + if not attr.restrictions.choice: + choice = id(attr) + attr.restrictions.choice = choice + attr.restrictions.path.append(("c", choice, 1, 1)) + @classmethod def create_substitution(cls, source: Class) -> Attr: """Create an attribute with type that refers to the given source class diff --git a/xsdata/codegen/handlers/calculate_attribute_paths.py b/xsdata/codegen/handlers/calculate_attribute_paths.py index ce54e8564..0a1ed88b8 100644 --- a/xsdata/codegen/handlers/calculate_attribute_paths.py +++ b/xsdata/codegen/handlers/calculate_attribute_paths.py @@ -42,6 +42,8 @@ def process_attr_path(cls, attr: Attr): attr.restrictions.choice = index elif name == GROUP: attr.restrictions.group = index + else: + pass min_occurs *= mi max_occurs *= ma diff --git a/xsdata/formats/dataclass/parsers/nodes/element.py b/xsdata/formats/dataclass/parsers/nodes/element.py index 29292396d..16b4a6fb4 100644 --- a/xsdata/formats/dataclass/parsers/nodes/element.py +++ b/xsdata/formats/dataclass/parsers/nodes/element.py @@ -367,10 +367,11 @@ def build_node( if var.clazz: return self.build_element_node( var.clazz, + var.derived, attrs, ns_map, position, - derived_factory if var.derived else None, + derived_factory, xsi_type, xsi_nil, ) @@ -395,10 +396,11 @@ def build_node( if clazz: node = self.build_element_node( clazz, + derived, attrs, ns_map, position, - derived_factory if derived else None, + derived_factory, xsi_type, xsi_nil, ) @@ -417,10 +419,11 @@ def build_node( def build_element_node( self, clazz: Type, + derived: bool, attrs: Dict, ns_map: Dict, position: int, - derived_factory: Optional[Type] = None, + derived_factory: Type, xsi_type: Optional[str] = None, xsi_nil: Optional[bool] = None, ) -> Optional[XmlNode]: @@ -429,6 +432,9 @@ def build_element_node( if not meta or (meta.nillable and xsi_nil is False): return None + if xsi_type and not derived and not issubclass(meta.clazz, clazz): + derived = True + return ElementNode( meta=meta, config=self.config, @@ -436,7 +442,7 @@ def build_element_node( ns_map=ns_map, context=self.context, position=position, - derived_factory=derived_factory, + derived_factory=derived_factory if derived else None, xsi_type=xsi_type, xsi_nil=xsi_nil, mixed=self.meta.mixed_content,