plotarea.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # Copyright (c) 2010-2024 openpyxl
  2. from openpyxl.descriptors.serialisable import Serialisable
  3. from openpyxl.descriptors import (
  4. Typed,
  5. Alias,
  6. )
  7. from openpyxl.descriptors.excel import (
  8. ExtensionList,
  9. )
  10. from openpyxl.descriptors.sequence import (
  11. MultiSequence,
  12. MultiSequencePart,
  13. )
  14. from openpyxl.descriptors.nested import (
  15. NestedBool,
  16. )
  17. from ._3d import _3DBase
  18. from .area_chart import AreaChart, AreaChart3D
  19. from .bar_chart import BarChart, BarChart3D
  20. from .bubble_chart import BubbleChart
  21. from .line_chart import LineChart, LineChart3D
  22. from .pie_chart import PieChart, PieChart3D, ProjectedPieChart, DoughnutChart
  23. from .radar_chart import RadarChart
  24. from .scatter_chart import ScatterChart
  25. from .stock_chart import StockChart
  26. from .surface_chart import SurfaceChart, SurfaceChart3D
  27. from .layout import Layout
  28. from .shapes import GraphicalProperties
  29. from .text import RichText
  30. from .axis import (
  31. NumericAxis,
  32. TextAxis,
  33. SeriesAxis,
  34. DateAxis,
  35. )
  36. class DataTable(Serialisable):
  37. tagname = "dTable"
  38. showHorzBorder = NestedBool(allow_none=True)
  39. showVertBorder = NestedBool(allow_none=True)
  40. showOutline = NestedBool(allow_none=True)
  41. showKeys = NestedBool(allow_none=True)
  42. spPr = Typed(expected_type=GraphicalProperties, allow_none=True)
  43. graphicalProperties = Alias('spPr')
  44. txPr = Typed(expected_type=RichText, allow_none=True)
  45. extLst = Typed(expected_type=ExtensionList, allow_none=True)
  46. __elements__ = ('showHorzBorder', 'showVertBorder', 'showOutline',
  47. 'showKeys', 'spPr', 'txPr')
  48. def __init__(self,
  49. showHorzBorder=None,
  50. showVertBorder=None,
  51. showOutline=None,
  52. showKeys=None,
  53. spPr=None,
  54. txPr=None,
  55. extLst=None,
  56. ):
  57. self.showHorzBorder = showHorzBorder
  58. self.showVertBorder = showVertBorder
  59. self.showOutline = showOutline
  60. self.showKeys = showKeys
  61. self.spPr = spPr
  62. self.txPr = txPr
  63. class PlotArea(Serialisable):
  64. tagname = "plotArea"
  65. layout = Typed(expected_type=Layout, allow_none=True)
  66. dTable = Typed(expected_type=DataTable, allow_none=True)
  67. spPr = Typed(expected_type=GraphicalProperties, allow_none=True)
  68. graphicalProperties = Alias("spPr")
  69. extLst = Typed(expected_type=ExtensionList, allow_none=True)
  70. # at least one chart
  71. _charts = MultiSequence()
  72. areaChart = MultiSequencePart(expected_type=AreaChart, store="_charts")
  73. area3DChart = MultiSequencePart(expected_type=AreaChart3D, store="_charts")
  74. lineChart = MultiSequencePart(expected_type=LineChart, store="_charts")
  75. line3DChart = MultiSequencePart(expected_type=LineChart3D, store="_charts")
  76. stockChart = MultiSequencePart(expected_type=StockChart, store="_charts")
  77. radarChart = MultiSequencePart(expected_type=RadarChart, store="_charts")
  78. scatterChart = MultiSequencePart(expected_type=ScatterChart, store="_charts")
  79. pieChart = MultiSequencePart(expected_type=PieChart, store="_charts")
  80. pie3DChart = MultiSequencePart(expected_type=PieChart3D, store="_charts")
  81. doughnutChart = MultiSequencePart(expected_type=DoughnutChart, store="_charts")
  82. barChart = MultiSequencePart(expected_type=BarChart, store="_charts")
  83. bar3DChart = MultiSequencePart(expected_type=BarChart3D, store="_charts")
  84. ofPieChart = MultiSequencePart(expected_type=ProjectedPieChart, store="_charts")
  85. surfaceChart = MultiSequencePart(expected_type=SurfaceChart, store="_charts")
  86. surface3DChart = MultiSequencePart(expected_type=SurfaceChart3D, store="_charts")
  87. bubbleChart = MultiSequencePart(expected_type=BubbleChart, store="_charts")
  88. # axes
  89. _axes = MultiSequence()
  90. valAx = MultiSequencePart(expected_type=NumericAxis, store="_axes")
  91. catAx = MultiSequencePart(expected_type=TextAxis, store="_axes")
  92. dateAx = MultiSequencePart(expected_type=DateAxis, store="_axes")
  93. serAx = MultiSequencePart(expected_type=SeriesAxis, store="_axes")
  94. __elements__ = ('layout', '_charts', '_axes', 'dTable', 'spPr')
  95. def __init__(self,
  96. layout=None,
  97. dTable=None,
  98. spPr=None,
  99. _charts=(),
  100. _axes=(),
  101. extLst=None,
  102. ):
  103. self.layout = layout
  104. self.dTable = dTable
  105. self.spPr = spPr
  106. self._charts = _charts
  107. self._axes = _axes
  108. def to_tree(self, tagname=None, idx=None, namespace=None):
  109. axIds = {ax.axId for ax in self._axes}
  110. for chart in self._charts:
  111. for id, axis in chart._axes.items():
  112. if id not in axIds:
  113. setattr(self, axis.tagname, axis)
  114. axIds.add(id)
  115. return super().to_tree(tagname)
  116. @classmethod
  117. def from_tree(cls, node):
  118. self = super().from_tree(node)
  119. axes = dict((axis.axId, axis) for axis in self._axes)
  120. for chart in self._charts:
  121. if isinstance(chart, (ScatterChart, BubbleChart)):
  122. x, y = (axes[axId] for axId in chart.axId)
  123. chart.x_axis = x
  124. chart.y_axis = y
  125. continue
  126. for axId in chart.axId:
  127. axis = axes.get(axId)
  128. if axis is None and isinstance(chart, _3DBase):
  129. # Series Axis can be optional
  130. chart.z_axis = None
  131. continue
  132. if axis.tagname in ("catAx", "dateAx"):
  133. chart.x_axis = axis
  134. elif axis.tagname == "valAx":
  135. chart.y_axis = axis
  136. elif axis.tagname == "serAx":
  137. chart.z_axis = axis
  138. return self