xmlfile.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. from __future__ import absolute_import
  2. # Copyright (c) 2010-2015 openpyxl
  3. """Implements the lxml.etree.xmlfile API using the standard library xml.etree"""
  4. from contextlib import contextmanager
  5. from xml.etree.ElementTree import (
  6. Element,
  7. _escape_cdata,
  8. )
  9. from . import incremental_tree
  10. class LxmlSyntaxError(Exception):
  11. pass
  12. class _IncrementalFileWriter(object):
  13. """Replacement for _IncrementalFileWriter of lxml"""
  14. def __init__(self, output_file):
  15. self._element_stack = []
  16. self._file = output_file
  17. self._have_root = False
  18. self.global_nsmap = incremental_tree.current_global_nsmap()
  19. self.is_html = False
  20. @contextmanager
  21. def element(self, tag, attrib=None, nsmap=None, **_extra):
  22. """Create a new xml element using a context manager."""
  23. if nsmap and None in nsmap:
  24. # Normalise None prefix (lxml's default namespace prefix) -> "", as
  25. # required for incremental_tree
  26. if "" in nsmap and nsmap[""] != nsmap[None]:
  27. raise ValueError(
  28. 'Found None and "" as default nsmap prefixes with different URIs'
  29. )
  30. nsmap = nsmap.copy()
  31. nsmap[""] = nsmap.pop(None)
  32. # __enter__ part
  33. self._have_root = True
  34. if attrib is None:
  35. attrib = {}
  36. elem = Element(tag, attrib=attrib, **_extra)
  37. elem.text = ''
  38. elem.tail = ''
  39. if self._element_stack:
  40. is_root = False
  41. (
  42. nsmap_scope,
  43. default_ns_attr_prefix,
  44. uri_to_prefix,
  45. ) = self._element_stack[-1]
  46. else:
  47. is_root = True
  48. nsmap_scope = {}
  49. default_ns_attr_prefix = None
  50. uri_to_prefix = {}
  51. (
  52. tag,
  53. nsmap_scope,
  54. default_ns_attr_prefix,
  55. uri_to_prefix,
  56. next_remains_root,
  57. ) = incremental_tree.write_elem_start(
  58. self._file,
  59. elem,
  60. nsmap_scope=nsmap_scope,
  61. global_nsmap=self.global_nsmap,
  62. short_empty_elements=False,
  63. is_html=self.is_html,
  64. is_root=is_root,
  65. uri_to_prefix=uri_to_prefix,
  66. default_ns_attr_prefix=default_ns_attr_prefix,
  67. new_nsmap=nsmap,
  68. )
  69. self._element_stack.append(
  70. (
  71. nsmap_scope,
  72. default_ns_attr_prefix,
  73. uri_to_prefix,
  74. )
  75. )
  76. yield
  77. # __exit__ part
  78. self._element_stack.pop()
  79. self._file(f"</{tag}>")
  80. if elem.tail:
  81. self._file(_escape_cdata(elem.tail))
  82. def write(self, arg):
  83. """Write a string or subelement."""
  84. if isinstance(arg, str):
  85. # it is not allowed to write a string outside of an element
  86. if not self._element_stack:
  87. raise LxmlSyntaxError()
  88. self._file(_escape_cdata(arg))
  89. else:
  90. if not self._element_stack and self._have_root:
  91. raise LxmlSyntaxError()
  92. if self._element_stack:
  93. is_root = False
  94. (
  95. nsmap_scope,
  96. default_ns_attr_prefix,
  97. uri_to_prefix,
  98. ) = self._element_stack[-1]
  99. else:
  100. is_root = True
  101. nsmap_scope = {}
  102. default_ns_attr_prefix = None
  103. uri_to_prefix = {}
  104. incremental_tree._serialize_ns_xml(
  105. self._file,
  106. arg,
  107. nsmap_scope=nsmap_scope,
  108. global_nsmap=self.global_nsmap,
  109. short_empty_elements=True,
  110. is_html=self.is_html,
  111. is_root=is_root,
  112. uri_to_prefix=uri_to_prefix,
  113. default_ns_attr_prefix=default_ns_attr_prefix,
  114. )
  115. def __enter__(self):
  116. pass
  117. def __exit__(self, type, value, traceback):
  118. # without root the xml document is incomplete
  119. if not self._have_root:
  120. raise LxmlSyntaxError()
  121. class xmlfile(object):
  122. """Context manager that can replace lxml.etree.xmlfile."""
  123. def __init__(self, output_file, buffered=False, encoding="utf-8", close=False):
  124. self._file = output_file
  125. self._close = close
  126. self.encoding = encoding
  127. self.writer_cm = None
  128. def __enter__(self):
  129. self.writer_cm = incremental_tree._get_writer(self._file, encoding=self.encoding)
  130. writer, declared_encoding = self.writer_cm.__enter__()
  131. return _IncrementalFileWriter(writer)
  132. def __exit__(self, type, value, traceback):
  133. if self.writer_cm:
  134. self.writer_cm.__exit__(type, value, traceback)
  135. if self._close:
  136. self._file.close()