stylesheet.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. # Copyright (c) 2010-2024 openpyxl
  2. from warnings import warn
  3. from openpyxl.descriptors.serialisable import Serialisable
  4. from openpyxl.descriptors import (
  5. Typed,
  6. )
  7. from openpyxl.descriptors.sequence import NestedSequence
  8. from openpyxl.descriptors.excel import ExtensionList
  9. from openpyxl.utils.indexed_list import IndexedList
  10. from openpyxl.xml.constants import ARC_STYLE, SHEET_MAIN_NS
  11. from openpyxl.xml.functions import fromstring
  12. from .builtins import styles
  13. from .colors import ColorList
  14. from .differential import DifferentialStyle
  15. from .table import TableStyleList
  16. from .borders import Border
  17. from .fills import Fill
  18. from .fonts import Font
  19. from .numbers import (
  20. NumberFormatList,
  21. BUILTIN_FORMATS,
  22. BUILTIN_FORMATS_MAX_SIZE,
  23. BUILTIN_FORMATS_REVERSE,
  24. is_date_format,
  25. is_timedelta_format,
  26. builtin_format_code
  27. )
  28. from .named_styles import (
  29. _NamedCellStyleList,
  30. NamedStyleList,
  31. NamedStyle,
  32. )
  33. from .cell_style import CellStyle, CellStyleList
  34. class Stylesheet(Serialisable):
  35. tagname = "styleSheet"
  36. numFmts = Typed(expected_type=NumberFormatList)
  37. fonts = NestedSequence(expected_type=Font, count=True)
  38. fills = NestedSequence(expected_type=Fill, count=True)
  39. borders = NestedSequence(expected_type=Border, count=True)
  40. cellStyleXfs = Typed(expected_type=CellStyleList)
  41. cellXfs = Typed(expected_type=CellStyleList)
  42. cellStyles = Typed(expected_type=_NamedCellStyleList)
  43. dxfs = NestedSequence(expected_type=DifferentialStyle, count=True)
  44. tableStyles = Typed(expected_type=TableStyleList, allow_none=True)
  45. colors = Typed(expected_type=ColorList, allow_none=True)
  46. extLst = Typed(expected_type=ExtensionList, allow_none=True)
  47. __elements__ = ('numFmts', 'fonts', 'fills', 'borders', 'cellStyleXfs',
  48. 'cellXfs', 'cellStyles', 'dxfs', 'tableStyles', 'colors')
  49. def __init__(self,
  50. numFmts=None,
  51. fonts=(),
  52. fills=(),
  53. borders=(),
  54. cellStyleXfs=None,
  55. cellXfs=None,
  56. cellStyles=None,
  57. dxfs=(),
  58. tableStyles=None,
  59. colors=None,
  60. extLst=None,
  61. ):
  62. if numFmts is None:
  63. numFmts = NumberFormatList()
  64. self.numFmts = numFmts
  65. self.number_formats = IndexedList()
  66. self.fonts = fonts
  67. self.fills = fills
  68. self.borders = borders
  69. if cellStyleXfs is None:
  70. cellStyleXfs = CellStyleList()
  71. self.cellStyleXfs = cellStyleXfs
  72. if cellXfs is None:
  73. cellXfs = CellStyleList()
  74. self.cellXfs = cellXfs
  75. if cellStyles is None:
  76. cellStyles = _NamedCellStyleList()
  77. self.cellStyles = cellStyles
  78. self.dxfs = dxfs
  79. self.tableStyles = tableStyles
  80. self.colors = colors
  81. self.cell_styles = self.cellXfs._to_array()
  82. self.alignments = self.cellXfs.alignments
  83. self.protections = self.cellXfs.prots
  84. self._normalise_numbers()
  85. self.named_styles = self._merge_named_styles()
  86. @classmethod
  87. def from_tree(cls, node):
  88. # strip all attribs
  89. attrs = dict(node.attrib)
  90. for k in attrs:
  91. del node.attrib[k]
  92. return super().from_tree(node)
  93. def _merge_named_styles(self):
  94. """
  95. Merge named style names "cellStyles" with their associated styles
  96. "cellStyleXfs"
  97. """
  98. style_refs = self.cellStyles.remove_duplicates()
  99. from_ref = [self._expand_named_style(style_ref) for style_ref in style_refs]
  100. return NamedStyleList(from_ref)
  101. def _expand_named_style(self, style_ref):
  102. """
  103. Expand a named style reference element to a
  104. named style object by binding the relevant
  105. objects from the stylesheet
  106. """
  107. xf = self.cellStyleXfs[style_ref.xfId]
  108. named_style = NamedStyle(
  109. name=style_ref.name,
  110. hidden=style_ref.hidden,
  111. builtinId=style_ref.builtinId,
  112. )
  113. named_style.font = self.fonts[xf.fontId]
  114. named_style.fill = self.fills[xf.fillId]
  115. named_style.border = self.borders[xf.borderId]
  116. if xf.numFmtId < BUILTIN_FORMATS_MAX_SIZE:
  117. formats = BUILTIN_FORMATS
  118. else:
  119. formats = self.custom_formats
  120. if xf.numFmtId in formats:
  121. named_style.number_format = formats[xf.numFmtId]
  122. if xf.alignment:
  123. named_style.alignment = xf.alignment
  124. if xf.protection:
  125. named_style.protection = xf.protection
  126. return named_style
  127. def _split_named_styles(self, wb):
  128. """
  129. Convert NamedStyle into separate CellStyle and Xf objects
  130. """
  131. for style in wb._named_styles:
  132. self.cellStyles.cellStyle.append(style.as_name())
  133. self.cellStyleXfs.xf.append(style.as_xf())
  134. @property
  135. def custom_formats(self):
  136. return dict([(n.numFmtId, n.formatCode) for n in self.numFmts.numFmt])
  137. def _normalise_numbers(self):
  138. """
  139. Rebase custom numFmtIds with a floor of 164 when reading stylesheet
  140. And index datetime formats
  141. """
  142. date_formats = set()
  143. timedelta_formats = set()
  144. custom = self.custom_formats
  145. formats = self.number_formats
  146. for idx, style in enumerate(self.cell_styles):
  147. if style.numFmtId in custom:
  148. fmt = custom[style.numFmtId]
  149. if fmt in BUILTIN_FORMATS_REVERSE: # remove builtins
  150. style.numFmtId = BUILTIN_FORMATS_REVERSE[fmt]
  151. else:
  152. style.numFmtId = formats.add(fmt) + BUILTIN_FORMATS_MAX_SIZE
  153. else:
  154. fmt = builtin_format_code(style.numFmtId)
  155. if is_date_format(fmt):
  156. # Create an index of which styles refer to datetimes
  157. date_formats.add(idx)
  158. if is_timedelta_format(fmt):
  159. # Create an index of which styles refer to timedeltas
  160. timedelta_formats.add(idx)
  161. self.date_formats = date_formats
  162. self.timedelta_formats = timedelta_formats
  163. def to_tree(self, tagname=None, idx=None, namespace=None):
  164. tree = super().to_tree(tagname, idx, namespace)
  165. tree.set("xmlns", SHEET_MAIN_NS)
  166. return tree
  167. def apply_stylesheet(archive, wb):
  168. """
  169. Add styles to workbook if present
  170. """
  171. try:
  172. src = archive.read(ARC_STYLE)
  173. except KeyError:
  174. return wb
  175. node = fromstring(src)
  176. stylesheet = Stylesheet.from_tree(node)
  177. if stylesheet.cell_styles:
  178. wb._borders = IndexedList(stylesheet.borders)
  179. wb._fonts = IndexedList(stylesheet.fonts)
  180. wb._fills = IndexedList(stylesheet.fills)
  181. wb._differential_styles.styles = stylesheet.dxfs
  182. wb._number_formats = stylesheet.number_formats
  183. wb._protections = stylesheet.protections
  184. wb._alignments = stylesheet.alignments
  185. wb._table_styles = stylesheet.tableStyles
  186. # need to overwrite openpyxl defaults in case workbook has different ones
  187. wb._cell_styles = stylesheet.cell_styles
  188. wb._named_styles = stylesheet.named_styles
  189. wb._date_formats = stylesheet.date_formats
  190. wb._timedelta_formats = stylesheet.timedelta_formats
  191. for ns in wb._named_styles:
  192. ns.bind(wb)
  193. else:
  194. warn("Workbook contains no stylesheet, using openpyxl's defaults")
  195. if not wb._named_styles:
  196. normal = styles['Normal']
  197. wb.add_named_style(normal)
  198. warn("Workbook contains no default style, apply openpyxl's default")
  199. if stylesheet.colors is not None:
  200. wb._colors = stylesheet.colors.index
  201. def write_stylesheet(wb):
  202. stylesheet = Stylesheet()
  203. stylesheet.fonts = wb._fonts
  204. stylesheet.fills = wb._fills
  205. stylesheet.borders = wb._borders
  206. stylesheet.dxfs = wb._differential_styles.styles
  207. stylesheet.colors = ColorList(indexedColors=wb._colors)
  208. from .numbers import NumberFormat
  209. fmts = []
  210. for idx, code in enumerate(wb._number_formats, BUILTIN_FORMATS_MAX_SIZE):
  211. fmt = NumberFormat(idx, code)
  212. fmts.append(fmt)
  213. stylesheet.numFmts.numFmt = fmts
  214. xfs = []
  215. for style in wb._cell_styles:
  216. xf = CellStyle.from_array(style)
  217. if style.alignmentId:
  218. xf.alignment = wb._alignments[style.alignmentId]
  219. if style.protectionId:
  220. xf.protection = wb._protections[style.protectionId]
  221. xfs.append(xf)
  222. stylesheet.cellXfs = CellStyleList(xf=xfs)
  223. stylesheet._split_named_styles(wb)
  224. stylesheet.tableStyles = wb._table_styles
  225. return stylesheet.to_tree()