Skip to content

Commit

Permalink
Implement iterwalk for xml.etree.ElementTree
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Jun 29, 2021
1 parent e311af3 commit adb7b38
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 23 deletions.
52 changes: 41 additions & 11 deletions tests/formats/dataclass/parsers/handlers/test_native.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import sys
from unittest import mock
from unittest.case import TestCase
from xml import etree

import pytest

from tests import fixtures_dir
from tests.fixtures.books import BookForm
from tests.fixtures.books import Books
from tests.fixtures.books.fixtures import books
from tests.fixtures.books.fixtures import events
Expand All @@ -10,6 +15,7 @@
from xsdata.formats.dataclass.parsers.bases import RecordParser
from xsdata.formats.dataclass.parsers.handlers import XmlEventHandler
from xsdata.formats.dataclass.parsers.handlers import XmlSaxHandler
from xsdata.formats.dataclass.parsers.handlers.native import get_base_url


class XmlEventHandlerTests(TestCase):
Expand All @@ -28,17 +34,6 @@ def test_parse_with_default_ns(self):
self.assertEqual({None: "urn:books"}, self.parser.ns_map)
self.assertEqual(events_default_ns, self.parser.events)

def test_parse_with_xinclude_raises_exception(self):
self.parser.config.process_xinclude = True
path = fixtures_dir.joinpath("books/books.xml")

with self.assertRaises(XmlHandlerError) as cm:
self.parser.from_path(path, Books)

self.assertEqual(
"XmlEventHandler doesn't support xinclude elements.", str(cm.exception)
)

def test_parse_context_with_unhandled_event(self):
context = [("reverse", None)]
handler = XmlEventHandler(parser=self.parser, clazz=Books)
Expand All @@ -48,6 +43,41 @@ def test_parse_context_with_unhandled_event(self):

self.assertEqual("Unhandled event: `reverse`.", str(cm.exception))

def test_parse_with_element_or_tree(self):
path = fixtures_dir.joinpath("books/books.xml")
tree = etree.ElementTree.parse(str(path))

result = self.parser.parse(tree, Books)
self.assertEqual(books, result)

tree = etree.ElementTree.parse(str(path))
result = self.parser.parse(tree.find(".//book"), BookForm)
self.assertEqual(books.book[0], result)

@pytest.mark.skipif(sys.platform == "win32", reason="urljoin + path sep")
def test_parse_with_xinclude(self):
path = fixtures_dir.joinpath("books/books-xinclude.xml")
ns_map = {"ns0": "urn:books"}

self.parser.config.process_xinclude = True
self.assertEqual(books, self.parser.parse(str(path), Books))
self.assertEqual(ns_map, self.parser.ns_map)

@pytest.mark.skipif(sys.platform == "win32", reason="urljoin + path sep")
def test_parse_with_xinclude_from_memory(self):
path = fixtures_dir.joinpath("books/books-xinclude.xml")
ns_map = {"ns0": "urn:books"}

self.parser.config.process_xinclude = True
self.parser.config.base_url = str(path)
self.assertEqual(books, self.parser.from_string(path.read_text(), Books))
self.assertEqual(ns_map, self.parser.ns_map)

def test_get_base_url(self):
self.assertIsNone(get_base_url(None, None))
self.assertIsNone(get_base_url(None, None))
self.assertEqual("config/", get_base_url("config/", "/tmp/foo.xml"))


class SaxHandlerTests(TestCase):
def setUp(self):
Expand Down
82 changes: 74 additions & 8 deletions xsdata/formats/dataclass/parsers/handlers/native.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import functools
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Optional
from typing import Tuple
from typing import Type
from urllib.parse import urljoin
from xml import sax
from xml.etree.ElementTree import iterparse
from xml.etree import ElementInclude as xinclude
from xml.etree import ElementTree as etree

from xsdata.exceptions import XmlHandlerError
from xsdata.formats.dataclass.parsers.mixins import PushParser
from xsdata.formats.dataclass.parsers.mixins import SaxHandler
from xsdata.formats.dataclass.parsers.mixins import XmlHandler
from xsdata.models.enums import EventType
from xsdata.utils import namespaces
from xsdata.utils.namespaces import build_qname

EVENTS = (EventType.START, EventType.END, EventType.START_NS)
Expand All @@ -29,16 +34,35 @@ class XmlEventHandler(XmlHandler):

def parse(self, source: Any) -> Any:
"""
Parse an XML document from a system identifier or an InputSource.
Parse an XML document from a system identifier or an InputSource or
directly from an xml Element or ElementTree.
:raises XmlHandlerError: If process xinclude config is enabled.
When source is an Element or ElementTree the handler will walk
over the objects structure.
When source is a system identifier or an InputSource the parser
will ignore comments and recover from errors.
When config process_xinclude is enabled the handler will parse
the whole document and then walk down the element tree.
"""
if self.parser.config.process_xinclude:
raise XmlHandlerError(
f"{type(self).__name__} doesn't support xinclude elements."
)

return self.process_context(iterparse(source, EVENTS)) # nosec
if isinstance(source, etree.ElementTree):
source = source.getroot()

if isinstance(source, etree.Element):
ctx = iterwalk(source, {})
elif self.parser.config.process_xinclude:
root = etree.parse(source).getroot() # nosec
base_url = get_base_url(self.parser.config.base_url, source)
loader = functools.partial(xinclude_loader, base_url=base_url)

xinclude.include(root, loader=loader)
ctx = iterwalk(root, {})
else:
ctx = etree.iterparse(source, EVENTS) # nosec

return self.process_context(ctx)

def process_context(self, context: Iterable) -> Any:
"""Iterate context and push the events to main parser."""
Expand Down Expand Up @@ -154,3 +178,45 @@ def startPrefixMapping(self, prefix: str, uri: Optional[str]):
:param uri: Namespace uri
"""
self.ns_map[prefix] = uri or ""


def iterwalk(element: etree.Element, ns_map: Dict) -> Iterator[Tuple[str, Any]]:
"""
Walk over the element tree structure and emit start-ns/start/end events.
The ElementTree doesn't preserve the original namespace prefixes, we
have to generate new ones.
"""
uri = namespaces.target_uri(element.tag)
if uri is not None:
prefix = namespaces.load_prefix(uri, ns_map)
yield EventType.START_NS, (prefix, uri)

yield EventType.START, element

for child in element:
yield from iterwalk(child, ns_map)

yield EventType.END, element


def get_base_url(base_url: Optional[str], source: Any) -> Optional[str]:

if base_url:
return base_url

return source if isinstance(source, str) else None


def xinclude_loader(
href: str,
parse: str,
encoding: Optional[str] = None,
base_url: Optional[str] = None,
) -> Any:
"""Custom loader for xinclude to support base_url argument that doesn't
exist for python < 3.9."""
if base_url:
href = urljoin(base_url, href)

return xinclude.default_loader(href, parse, encoding)
5 changes: 1 addition & 4 deletions xsdata/formats/dataclass/serializers/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,7 @@ def encode(cls, value: Any, var: XmlVar) -> Any:
if isinstance(value, (str, QName)) or var is None:
return value

if isinstance(value, tuple) and hasattr(value, "_fields"):
return converter.serialize(value, format=var.format)

if isinstance(value, (tuple, list)):
if collections.is_array(value):
return [cls.encode(v, var) for v in value]

if isinstance(value, Enum):
Expand Down

0 comments on commit adb7b38

Please sign in to comment.