_chart.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # Copyright (c) 2010-2024 openpyxl
  2. from collections import OrderedDict
  3. from operator import attrgetter
  4. from openpyxl.descriptors import (
  5. Typed,
  6. Integer,
  7. Alias,
  8. MinMax,
  9. Bool,
  10. Set,
  11. )
  12. from openpyxl.descriptors.sequence import ValueSequence
  13. from openpyxl.descriptors.serialisable import Serialisable
  14. from ._3d import _3DBase
  15. from .data_source import AxDataSource, NumRef
  16. from .layout import Layout
  17. from .legend import Legend
  18. from .reference import Reference
  19. from .series_factory import SeriesFactory
  20. from .series import attribute_mapping
  21. from .shapes import GraphicalProperties
  22. from .title import TitleDescriptor
  23. class AxId(Serialisable):
  24. val = Integer()
  25. def __init__(self, val):
  26. self.val = val
  27. def PlotArea():
  28. from .chartspace import PlotArea
  29. return PlotArea()
  30. class ChartBase(Serialisable):
  31. """
  32. Base class for all charts
  33. """
  34. legend = Typed(expected_type=Legend, allow_none=True)
  35. layout = Typed(expected_type=Layout, allow_none=True)
  36. roundedCorners = Bool(allow_none=True)
  37. axId = ValueSequence(expected_type=int)
  38. visible_cells_only = Bool(allow_none=True)
  39. display_blanks = Set(values=['span', 'gap', 'zero'])
  40. graphical_properties = Typed(expected_type=GraphicalProperties, allow_none=True)
  41. _series_type = ""
  42. ser = ()
  43. series = Alias('ser')
  44. title = TitleDescriptor()
  45. anchor = "E15" # default anchor position
  46. width = 15 # in cm, approx 5 rows
  47. height = 7.5 # in cm, approx 14 rows
  48. _id = 1
  49. _path = "/xl/charts/chart{0}.xml"
  50. style = MinMax(allow_none=True, min=1, max=48)
  51. mime_type = "application/vnd.openxmlformats-officedocument.drawingml.chart+xml"
  52. graphical_properties = Typed(expected_type=GraphicalProperties, allow_none=True) # mapped to chartspace
  53. __elements__ = ()
  54. def __init__(self, axId=(), **kw):
  55. self._charts = [self]
  56. self.title = None
  57. self.layout = None
  58. self.roundedCorners = None
  59. self.legend = Legend()
  60. self.graphical_properties = None
  61. self.style = None
  62. self.plot_area = PlotArea()
  63. self.axId = axId
  64. self.display_blanks = 'gap'
  65. self.pivotSource = None
  66. self.pivotFormats = ()
  67. self.visible_cells_only = True
  68. self.idx_base = 0
  69. self.graphical_properties = None
  70. super().__init__()
  71. def __hash__(self):
  72. """
  73. Just need to check for identity
  74. """
  75. return id(self)
  76. def __iadd__(self, other):
  77. """
  78. Combine the chart with another one
  79. """
  80. if not isinstance(other, ChartBase):
  81. raise TypeError("Only other charts can be added")
  82. self._charts.append(other)
  83. return self
  84. def to_tree(self, namespace=None, tagname=None, idx=None):
  85. self.axId = [id for id in self._axes]
  86. if self.ser is not None:
  87. for s in self.ser:
  88. s.__elements__ = attribute_mapping[self._series_type]
  89. return super().to_tree(tagname, idx)
  90. def _reindex(self):
  91. """
  92. Normalise and rebase series: sort by order and then rebase order
  93. """
  94. # sort data series in order and rebase
  95. ds = sorted(self.series, key=attrgetter("order"))
  96. for idx, s in enumerate(ds):
  97. s.order = idx
  98. self.series = ds
  99. def _write(self):
  100. from .chartspace import ChartSpace, ChartContainer
  101. self.plot_area.layout = self.layout
  102. idx_base = self.idx_base
  103. for chart in self._charts:
  104. if chart not in self.plot_area._charts:
  105. chart.idx_base = idx_base
  106. idx_base += len(chart.series)
  107. self.plot_area._charts = self._charts
  108. container = ChartContainer(plotArea=self.plot_area, legend=self.legend, title=self.title)
  109. if isinstance(chart, _3DBase):
  110. container.view3D = chart.view3D
  111. container.floor = chart.floor
  112. container.sideWall = chart.sideWall
  113. container.backWall = chart.backWall
  114. container.plotVisOnly = self.visible_cells_only
  115. container.dispBlanksAs = self.display_blanks
  116. container.pivotFmts = self.pivotFormats
  117. cs = ChartSpace(chart=container)
  118. cs.style = self.style
  119. cs.roundedCorners = self.roundedCorners
  120. cs.pivotSource = self.pivotSource
  121. cs.spPr = self.graphical_properties
  122. return cs.to_tree()
  123. @property
  124. def _axes(self):
  125. x = getattr(self, "x_axis", None)
  126. y = getattr(self, "y_axis", None)
  127. z = getattr(self, "z_axis", None)
  128. return OrderedDict([(axis.axId, axis) for axis in (x, y, z) if axis])
  129. def set_categories(self, labels):
  130. """
  131. Set the categories / x-axis values
  132. """
  133. if not isinstance(labels, Reference):
  134. labels = Reference(range_string=labels)
  135. for s in self.ser:
  136. s.cat = AxDataSource(numRef=NumRef(f=labels))
  137. def add_data(self, data, from_rows=False, titles_from_data=False):
  138. """
  139. Add a range of data in a single pass.
  140. The default is to treat each column as a data series.
  141. """
  142. if not isinstance(data, Reference):
  143. data = Reference(range_string=data)
  144. if from_rows:
  145. values = data.rows
  146. else:
  147. values = data.cols
  148. for ref in values:
  149. series = SeriesFactory(ref, title_from_data=titles_from_data)
  150. self.series.append(series)
  151. def append(self, value):
  152. """Append a data series to the chart"""
  153. l = self.series[:]
  154. l.append(value)
  155. self.series = l
  156. @property
  157. def path(self):
  158. return self._path.format(self._id)