Skip to content

Commit

Permalink
Add support for custom decorators and base classes
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed May 28, 2023
1 parent 5eaf612 commit c00e4f8
Show file tree
Hide file tree
Showing 9 changed files with 343 additions and 13 deletions.
3 changes: 3 additions & 0 deletions docs/api/codegen.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ like naming conventions and substitutions.
OutputFormat
GeneratorConventions
GeneratorSubstitutions
GeneratorExtensions
StructureStyle
DocstringStyle
ClassFilterStrategy
CompoundFields
ObjectType
ExtensionType
GeneratorSubstitution
GeneratorExtension
NameConvention
NameCase
1 change: 1 addition & 0 deletions docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Code Generation
examples/dtd-modeling
examples/compound-fields
examples/dataclasses-features
examples/extending-models


Data Binding
Expand Down
49 changes: 49 additions & 0 deletions docs/examples/extending-models.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
================
Extending Models
================



Creating subclasses from the generated models require to repeat the original `Meta`
class and it's generally discouraged. You can instead apply base classes and
decorators through the code generation configuration.


The following configuration will add a base class and a decorator to all the
generated classes.

Read :ref:`more <GeneratorExtension>`.


.. code-block:: xml
<?xml version="1.0" encoding="UTF-8"?>
<Config xmlns="http://pypi.org/project/xsdata" version="23.6">
<Extensions>
<Extension type="class" class=".*" import="dataclasses_jsonschema.JsonSchemaMixin" prepend="false" applyIfDerived="false"/>
<Extension type="decorator" class=".*" import="typed_dataclass.typed_dataclass" prepend="false" applyIfDerived="false"/>
</Extensions>
</Config>
.. code-block:: python
from dataclasses import dataclass, field
from dataclasses_jsonschema import JsonSchemaMixin
from typed_dataclass import typed_dataclass
from typing import Optional
@dataclass
@typed_dataclass
class Cores(JsonSchemaMixin):
class Meta:
name = "cores"
core: Optional[str] = field(
default=None,
metadata={
"type": "Element",
"required": True,
}
)
1 change: 1 addition & 0 deletions tests/fixtures/stripe/.xsdata.xml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
</Conventions>
<Aliases/>
<Substitutions/>
<Extensions />
</Config>
122 changes: 122 additions & 0 deletions tests/formats/dataclass/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from xsdata.codegen.models import Restrictions
from xsdata.formats.dataclass.filters import Filters
from xsdata.models.config import DocstringStyle
from xsdata.models.config import ExtensionType
from xsdata.models.config import GeneratorConfig
from xsdata.models.config import GeneratorExtension
from xsdata.models.config import GeneratorSubstitution
from xsdata.models.config import NameCase
from xsdata.models.config import ObjectType
Expand All @@ -14,6 +16,8 @@
from xsdata.models.enums import Tag
from xsdata.utils.testing import AttrFactory
from xsdata.utils.testing import AttrTypeFactory
from xsdata.utils.testing import ClassFactory
from xsdata.utils.testing import ExtensionFactory
from xsdata.utils.testing import FactoryTestCase

type_str = AttrTypeFactory.native(DataType.STRING)
Expand Down Expand Up @@ -42,6 +46,102 @@ def test_class_name(self):
self.assertEqual("TypeType", self.filters.class_name(".*"))
self.assertEqual("Cbad", self.filters.class_name("abcd"))

def test_class_bases(self):
etp = ExtensionType.CLASS
self.filters.extensions[etp] = [
GeneratorExtension(
type=etp,
class_name=".*Bar",
import_string="a.b",
apply_if_derived=True,
prepend=False,
),
GeneratorExtension(
type=etp,
class_name="Foo.*",
import_string="a.b",
apply_if_derived=True,
prepend=True,
),
GeneratorExtension(
type=etp,
class_name="Foo.*",
import_string="a.c",
apply_if_derived=True,
prepend=True,
),
GeneratorExtension(
type=etp,
class_name="Foo.*",
import_string="a.d",
apply_if_derived=False,
prepend=True,
),
GeneratorExtension(
type=etp,
class_name="Nope.*",
import_string="a.e",
apply_if_derived=True,
prepend=False,
),
]
target = ClassFactory.create(extensions=ExtensionFactory.list(1))

expected = self.filters.class_bases(target, "FooBar")
self.assertEqual(["c", "b", "AttrB"], expected)

target.extensions.clear()
expected = self.filters.class_bases(target, "FooBar")
self.assertEqual(["d", "c", "b"], expected)

def test_class_annotations(self):
etp = ExtensionType.DECORATOR
self.filters.extensions[etp] = [
GeneratorExtension(
type=etp,
class_name=".*Bar",
import_string="a.b",
apply_if_derived=True,
prepend=False,
),
GeneratorExtension(
type=etp,
class_name="Foo.*",
import_string="a.b",
apply_if_derived=True,
prepend=True,
),
GeneratorExtension(
type=etp,
class_name="Foo.*",
import_string="a.c",
apply_if_derived=True,
prepend=True,
),
GeneratorExtension(
type=etp,
class_name="Foo.*",
import_string="a.d",
apply_if_derived=False,
prepend=False,
),
GeneratorExtension(
type=etp,
class_name="Nope.*",
import_string="a.e",
apply_if_derived=True,
prepend=False,
),
]
target = ClassFactory.create(extensions=ExtensionFactory.list(1))

expected = self.filters.class_annotations(target, "FooBar")
self.assertEqual(["@c", "@b", "@dataclass"], expected)

target.extensions.clear()
expected = self.filters.class_annotations(target, "FooBar")
self.assertEqual(["@c", "@b", "@dataclass", "@d"], expected)

def test_field_name(self):
self.filters.substitutions[ObjectType.FIELD]["abc"] = "cba"

Expand Down Expand Up @@ -817,6 +917,14 @@ def test__init(self):
config.substitutions.substitution.append(
GeneratorSubstitution(ObjectType.PACKAGE, "m", "n")
)
config.extensions.extension.extend(
[
GeneratorExtension(ExtensionType.DECORATOR, "a", "a.b"),
GeneratorExtension(ExtensionType.DECORATOR, "b", "a.c"),
GeneratorExtension(ExtensionType.CLASS, "c", "a.d"),
GeneratorExtension(ExtensionType.CLASS, "d", "a.e"),
]
)

filters = Filters(config)

Expand All @@ -839,3 +947,17 @@ def test__init(self):
ObjectType.PACKAGE: {"m": "n"},
}
self.assertEqual(expected_substitutions, filters.substitutions)

expected_extensions = {
ExtensionType.DECORATOR: config.extensions.extension[0:2],
ExtensionType.CLASS: config.extensions.extension[2:4],
}
self.assertEqual(expected_extensions, filters.extensions)

expected_imports = {
"b": {"@b"},
"c": {"@c"},
"d": {"(d", " d)"},
"e": {"(e", " e)"},
}
self.assertEqual(expected_imports, filters.import_patterns["a"])
12 changes: 12 additions & 0 deletions tests/models/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from xsdata import __version__
from xsdata.exceptions import GeneratorConfigError
from xsdata.exceptions import ParserError
from xsdata.models.config import ExtensionType
from xsdata.models.config import GeneratorAlias
from xsdata.models.config import GeneratorAliases
from xsdata.models.config import GeneratorConfig
from xsdata.models.config import GeneratorExtension
from xsdata.models.config import GeneratorOutput
from xsdata.models.config import ObjectType
from xsdata.models.config import OutputFormat
Expand Down Expand Up @@ -60,6 +62,7 @@ def test_create(self):
' <Substitution type="package" search="http://schemas.xmlsoap.org/soap/envelope/" replace="soapenv"/>\n'
' <Substitution type="class" search="(.*)Class$" replace="\\1Type"/>\n'
" </Substitutions>\n"
" <Extensions/>\n"
"</Config>\n"
)
self.assertEqual(expected, file_path.read_text())
Expand Down Expand Up @@ -111,6 +114,7 @@ def test_read(self):
" </Conventions>\n"
" <Aliases/>\n"
" <Substitutions/>\n"
" <Extensions/>\n"
"</Config>\n"
)
self.assertEqual(expected, file_path.read_text())
Expand Down Expand Up @@ -209,3 +213,11 @@ def test_init_config_with_aliases(self):
config = GeneratorConfig.read(output_path)
self.assertIsNone(config.aliases)
self.assertEqual(4, len(config.substitutions.substitution))

def test_extension_with_invalid_import_string(self):
cases = [None, "", "bar"]
for case in cases:
with self.assertRaises(GeneratorConfigError) as cm:
GeneratorExtension(type=ExtensionType.DECORATOR, import_string=case)

self.assertEqual(f"Invalid extension import '{case}'", str(cm.exception))
Loading

0 comments on commit c00e4f8

Please sign in to comment.