serialisable.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. # Copyright (c) 2010-2024 openpyxl
  2. from copy import copy
  3. from keyword import kwlist
# Reserved words of the Python language.  XML attribute and element names that
# match one of these are masked with a leading underscore before they are used
# as Python names.
KEYWORDS = frozenset(kwlist)
  5. from . import Descriptor
  6. from . import MetaSerialisable
  7. from .sequence import (
  8. Sequence,
  9. NestedSequence,
  10. MultiSequencePart,
  11. )
  12. from .namespace import namespaced
  13. from openpyxl.compat import safe_string
  14. from openpyxl.xml.functions import (
  15. Element,
  16. localname,
  17. )
# Container types that Serialisable.to_tree() treats as a sequence of children
# rather than as a single child object.
seq_types = (list, tuple)
  19. class Serialisable(metaclass=MetaSerialisable):
  20. """
  21. Objects can serialise to XML their attributes and child objects.
  22. The following class attributes are created by the metaclass at runtime:
  23. __attrs__ = attributes
  24. __nested__ = single-valued child treated as an attribute
  25. __elements__ = child elements
  26. """
  27. __attrs__ = None
  28. __nested__ = None
  29. __elements__ = None
  30. __namespaced__ = None
  31. idx_base = 0
  32. @property
  33. def tagname(self):
  34. raise(NotImplementedError)
  35. namespace = None
  36. @classmethod
  37. def from_tree(cls, node):
  38. """
  39. Create object from XML
  40. """
  41. # strip known namespaces from attributes
  42. attrib = dict(node.attrib)
  43. for key, ns in cls.__namespaced__:
  44. if ns in attrib:
  45. attrib[key] = attrib[ns]
  46. del attrib[ns]
  47. # strip attributes with unknown namespaces
  48. for key in list(attrib):
  49. if key.startswith('{'):
  50. del attrib[key]
  51. elif key in KEYWORDS:
  52. attrib["_" + key] = attrib[key]
  53. del attrib[key]
  54. elif "-" in key:
  55. n = key.replace("-", "_")
  56. attrib[n] = attrib[key]
  57. del attrib[key]
  58. if node.text and "attr_text" in cls.__attrs__:
  59. attrib["attr_text"] = node.text
  60. for el in node:
  61. tag = localname(el)
  62. if tag in KEYWORDS:
  63. tag = "_" + tag
  64. desc = getattr(cls, tag, None)
  65. if desc is None or isinstance(desc, property):
  66. continue
  67. if hasattr(desc, 'from_tree'):
  68. #descriptor manages conversion
  69. obj = desc.from_tree(el)
  70. else:
  71. if hasattr(desc.expected_type, "from_tree"):
  72. #complex type
  73. obj = desc.expected_type.from_tree(el)
  74. else:
  75. #primitive
  76. obj = el.text
  77. if isinstance(desc, NestedSequence):
  78. attrib[tag] = obj
  79. elif isinstance(desc, Sequence):
  80. attrib.setdefault(tag, [])
  81. attrib[tag].append(obj)
  82. elif isinstance(desc, MultiSequencePart):
  83. attrib.setdefault(desc.store, [])
  84. attrib[desc.store].append(obj)
  85. else:
  86. attrib[tag] = obj
  87. return cls(**attrib)
  88. def to_tree(self, tagname=None, idx=None, namespace=None):
  89. if tagname is None:
  90. tagname = self.tagname
  91. # keywords have to be masked
  92. if tagname.startswith("_"):
  93. tagname = tagname[1:]
  94. tagname = namespaced(self, tagname, namespace)
  95. namespace = getattr(self, "namespace", namespace)
  96. attrs = dict(self)
  97. for key, ns in self.__namespaced__:
  98. if key in attrs:
  99. attrs[ns] = attrs[key]
  100. del attrs[key]
  101. el = Element(tagname, attrs)
  102. if "attr_text" in self.__attrs__:
  103. el.text = safe_string(getattr(self, "attr_text"))
  104. for child_tag in self.__elements__:
  105. desc = getattr(self.__class__, child_tag, None)
  106. obj = getattr(self, child_tag)
  107. if hasattr(desc, "namespace") and hasattr(obj, 'namespace'):
  108. obj.namespace = desc.namespace
  109. if isinstance(obj, seq_types):
  110. if isinstance(desc, NestedSequence):
  111. # wrap sequence in container
  112. if not obj:
  113. continue
  114. nodes = [desc.to_tree(child_tag, obj, namespace)]
  115. elif isinstance(desc, Sequence):
  116. # sequence
  117. desc.idx_base = self.idx_base
  118. nodes = (desc.to_tree(child_tag, obj, namespace))
  119. else: # property
  120. nodes = (v.to_tree(child_tag, namespace) for v in obj)
  121. for node in nodes:
  122. el.append(node)
  123. else:
  124. if child_tag in self.__nested__:
  125. node = desc.to_tree(child_tag, obj, namespace)
  126. elif obj is None:
  127. continue
  128. else:
  129. node = obj.to_tree(child_tag)
  130. if node is not None:
  131. el.append(node)
  132. return el
  133. def __iter__(self):
  134. for attr in self.__attrs__:
  135. value = getattr(self, attr)
  136. if attr.startswith("_"):
  137. attr = attr[1:]
  138. elif attr != "attr_text" and "_" in attr:
  139. desc = getattr(self.__class__, attr)
  140. if getattr(desc, "hyphenated", False):
  141. attr = attr.replace("_", "-")
  142. if attr != "attr_text" and value is not None:
  143. yield attr, safe_string(value)
  144. def __eq__(self, other):
  145. if not self.__class__ == other.__class__:
  146. return False
  147. elif not dict(self) == dict(other):
  148. return False
  149. for el in self.__elements__:
  150. if getattr(self, el) != getattr(other, el):
  151. return False
  152. return True
  153. def __ne__(self, other):
  154. return not self == other
  155. def __repr__(self):
  156. s = u"<{0}.{1} object>\nParameters:".format(
  157. self.__module__,
  158. self.__class__.__name__
  159. )
  160. args = []
  161. for k in self.__attrs__ + self.__elements__:
  162. v = getattr(self, k)
  163. if isinstance(v, Descriptor):
  164. v = None
  165. args.append(u"{0}={1}".format(k, repr(v)))
  166. args = u", ".join(args)
  167. return u"\n".join([s, args])
  168. def __hash__(self):
  169. fields = []
  170. for attr in self.__attrs__ + self.__elements__:
  171. val = getattr(self, attr)
  172. if isinstance(val, list):
  173. val = tuple(val)
  174. fields.append(val)
  175. return hash(tuple(fields))
  176. def __add__(self, other):
  177. if type(self) != type(other):
  178. raise TypeError("Cannot combine instances of different types")
  179. vals = {}
  180. for attr in self.__attrs__:
  181. vals[attr] = getattr(self, attr) or getattr(other, attr)
  182. for el in self.__elements__:
  183. a = getattr(self, el)
  184. b = getattr(other, el)
  185. if a and b:
  186. vals[el] = a + b
  187. else:
  188. vals[el] = a or b
  189. return self.__class__(**vals)
  190. def __copy__(self):
  191. # serialise to xml and back to avoid shallow copies
  192. xml = self.to_tree(tagname="dummy")
  193. cp = self.__class__.from_tree(xml)
  194. # copy any non-persisted attributed
  195. for k in self.__dict__:
  196. if k not in self.__attrs__ + self.__elements__:
  197. v = copy(getattr(self, k))
  198. setattr(cp, k, v)
  199. return cp