som_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import colorsys
  4. from dataclasses import dataclass
  5. from typing import List, Tuple
  6. import cv2
  7. import matplotlib as mpl
  8. import matplotlib.colors as mplc
  9. import numpy as np
  10. import pycocotools.mask as mask_utils
  11. def rgb_to_hex(rgb_color):
  12. """
  13. Convert a rgb color to hex color.
  14. Args:
  15. rgb_color (tuple/list of ints): RGB color in tuple or list format.
  16. Returns:
  17. str: Hex color.
  18. Example:
  19. ```
  20. >>> rgb_to_hex((255, 0, 244))
  21. '#ff00ff'
  22. ```
  23. """
  24. return "#" + "".join([hex(c)[2:].zfill(2) for c in rgb_color])
  25. # DEFAULT_COLOR_HEX_TO_NAME = {
  26. # rgb_to_hex((255, 0, 0)): "red",
  27. # rgb_to_hex((0, 255, 0)): "lime",
  28. # rgb_to_hex((0, 0, 255)): "blue",
  29. # rgb_to_hex((255, 255, 0)): "yellow",
  30. # rgb_to_hex((255, 0, 255)): "fuchsia",
  31. # rgb_to_hex((0, 255, 255)): "aqua",
  32. # rgb_to_hex((255, 165, 0)): "orange",
  33. # rgb_to_hex((128, 0, 128)): "purple",
  34. # rgb_to_hex((255, 215, 0)): "gold",
  35. # }
  36. # Assuming rgb_to_hex is a function that converts an (R, G, B) tuple to a hex string.
  37. # For example: def rgb_to_hex(rgb): return '#%02x%02x%02x' % rgb
  38. DEFAULT_COLOR_HEX_TO_NAME = {
  39. # The top 20 approved colors
  40. rgb_to_hex((255, 255, 0)): "yellow",
  41. rgb_to_hex((0, 255, 0)): "lime",
  42. rgb_to_hex((0, 255, 255)): "cyan",
  43. rgb_to_hex((255, 0, 255)): "magenta",
  44. rgb_to_hex((255, 0, 0)): "red",
  45. rgb_to_hex((255, 127, 0)): "orange",
  46. rgb_to_hex((127, 255, 0)): "chartreuse",
  47. rgb_to_hex((0, 255, 127)): "spring green",
  48. rgb_to_hex((255, 0, 127)): "rose",
  49. rgb_to_hex((127, 0, 255)): "violet",
  50. rgb_to_hex((192, 255, 0)): "electric lime",
  51. rgb_to_hex((255, 192, 0)): "vivid orange",
  52. rgb_to_hex((0, 255, 192)): "turquoise",
  53. rgb_to_hex((192, 0, 255)): "bright violet",
  54. rgb_to_hex((255, 0, 192)): "bright pink",
  55. rgb_to_hex((255, 64, 0)): "fiery orange",
  56. rgb_to_hex((64, 255, 0)): "bright chartreuse",
  57. rgb_to_hex((0, 255, 64)): "malachite",
  58. rgb_to_hex((64, 0, 255)): "deep violet",
  59. rgb_to_hex((255, 0, 64)): "hot pink",
  60. }
  61. DEFAULT_COLOR_PALETTE = list(DEFAULT_COLOR_HEX_TO_NAME.keys())
  62. def _validate_color_hex(color_hex: str):
  63. color_hex = color_hex.lstrip("#")
  64. if not all(c in "0123456789abcdefABCDEF" for c in color_hex):
  65. raise ValueError("Invalid characters in color hash")
  66. if len(color_hex) not in (3, 6):
  67. raise ValueError("Invalid length of color hash")
  68. # copied from https://github.com/roboflow/supervision/blob/c8f557af0c61b5c03392bad2cc36c8835598b1e1/supervision/draw/color.py
  69. @dataclass
  70. class Color:
  71. """
  72. Represents a color in RGB format.
  73. Attributes:
  74. r (int): Red channel.
  75. g (int): Green channel.
  76. b (int): Blue channel.
  77. """
  78. r: int
  79. g: int
  80. b: int
  81. @classmethod
  82. def from_hex(cls, color_hex: str):
  83. """
  84. Create a Color instance from a hex string.
  85. Args:
  86. color_hex (str): Hex string of the color.
  87. Returns:
  88. Color: Instance representing the color.
  89. Example:
  90. ```
  91. >>> Color.from_hex('#ff00ff')
  92. Color(r=255, g=0, b=255)
  93. ```
  94. """
  95. _validate_color_hex(color_hex)
  96. color_hex = color_hex.lstrip("#")
  97. if len(color_hex) == 3:
  98. color_hex = "".join(c * 2 for c in color_hex)
  99. r, g, b = (int(color_hex[i : i + 2], 16) for i in range(0, 6, 2))
  100. return cls(r, g, b)
  101. @classmethod
  102. def to_hex(cls, color):
  103. """
  104. Convert a Color instance to a hex string.
  105. Args:
  106. color (Color): Color instance of color.
  107. Returns:
  108. Color: a hex string.
  109. """
  110. return rgb_to_hex((color.r, color.g, color.b))
  111. def as_rgb(self) -> Tuple[int, int, int]:
  112. """
  113. Returns the color as an RGB tuple.
  114. Returns:
  115. Tuple[int, int, int]: RGB tuple.
  116. Example:
  117. ```
  118. >>> color.as_rgb()
  119. (255, 0, 255)
  120. ```
  121. """
  122. return self.r, self.g, self.b
  123. def as_bgr(self) -> Tuple[int, int, int]:
  124. """
  125. Returns the color as a BGR tuple.
  126. Returns:
  127. Tuple[int, int, int]: BGR tuple.
  128. Example:
  129. ```
  130. >>> color.as_bgr()
  131. (255, 0, 255)
  132. ```
  133. """
  134. return self.b, self.g, self.r
  135. @classmethod
  136. def white(cls):
  137. return Color.from_hex(color_hex="#ffffff")
  138. @classmethod
  139. def black(cls):
  140. return Color.from_hex(color_hex="#000000")
  141. @classmethod
  142. def red(cls):
  143. return Color.from_hex(color_hex="#ff0000")
  144. @classmethod
  145. def green(cls):
  146. return Color.from_hex(color_hex="#00ff00")
  147. @classmethod
  148. def blue(cls):
  149. return Color.from_hex(color_hex="#0000ff")
  150. @dataclass
  151. class ColorPalette:
  152. colors: List[Color]
  153. @classmethod
  154. def default(cls):
  155. """
  156. Returns a default color palette.
  157. Returns:
  158. ColorPalette: A ColorPalette instance with default colors.
  159. Example:
  160. ```
  161. >>> ColorPalette.default()
  162. ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
  163. ```
  164. """
  165. return ColorPalette.from_hex(color_hex_list=DEFAULT_COLOR_PALETTE)
  166. @classmethod
  167. def from_hex(cls, color_hex_list: List[str]):
  168. """
  169. Create a ColorPalette instance from a list of hex strings.
  170. Args:
  171. color_hex_list (List[str]): List of color hex strings.
  172. Returns:
  173. ColorPalette: A ColorPalette instance.
  174. Example:
  175. ```
  176. >>> ColorPalette.from_hex(['#ff0000', '#00ff00', '#0000ff'])
  177. ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
  178. ```
  179. """
  180. colors = [Color.from_hex(color_hex) for color_hex in color_hex_list]
  181. return cls(colors)
  182. def by_idx(self, idx: int) -> Color:
  183. """
  184. Return the color at a given index in the palette.
  185. Args:
  186. idx (int): Index of the color in the palette.
  187. Returns:
  188. Color: Color at the given index.
  189. Example:
  190. ```
  191. >>> color_palette.by_idx(1)
  192. Color(r=0, g=255, b=0)
  193. ```
  194. """
  195. if idx < 0:
  196. raise ValueError("idx argument should not be negative")
  197. idx = idx % len(self.colors)
  198. return self.colors[idx]
  199. def find_farthest_color(self, img_array):
  200. """
  201. Return the color that is the farthest from the given color.
  202. Args:
  203. img_array (np array): any *x3 np array, 3 is the RGB color channel.
  204. Returns:
  205. Color: Farthest color.
  206. """
  207. # Reshape the image array for broadcasting
  208. img_array = img_array.reshape((-1, 3))
  209. # Convert colors dictionary to a NumPy array
  210. color_values = np.array([[c.r, c.g, c.b] for c in self.colors])
  211. # Calculate the Euclidean distance between the colors and each pixel in the image
  212. # Broadcasting happens here: img_array shape is (num_pixels, 3), color_values shape is (num_colors, 3)
  213. distances = np.sqrt(
  214. np.sum((img_array[:, np.newaxis, :] - color_values) ** 2, axis=2)
  215. )
  216. # Average the distances for each color
  217. mean_distances = np.mean(distances, axis=0)
  218. # return the farthest color
  219. farthest_idx = np.argmax(mean_distances)
  220. farthest_color = self.colors[farthest_idx]
  221. farthest_color_hex = Color.to_hex(farthest_color)
  222. if farthest_color_hex in DEFAULT_COLOR_HEX_TO_NAME:
  223. farthest_color_name = DEFAULT_COLOR_HEX_TO_NAME[farthest_color_hex]
  224. else:
  225. farthest_color_name = "unknown"
  226. return farthest_color, farthest_color_name
  227. def draw_box(ax, box_coord, alpha=0.8, edge_color="g", line_style="-", linewidth=2.0):
  228. x0, y0, width, height = box_coord
  229. ax.add_patch(
  230. mpl.patches.Rectangle(
  231. (x0, y0),
  232. width,
  233. height,
  234. fill=False,
  235. edgecolor=edge_color,
  236. linewidth=linewidth,
  237. alpha=alpha,
  238. linestyle=line_style,
  239. )
  240. )
  241. def draw_text(
  242. ax,
  243. text,
  244. position,
  245. font_size=None,
  246. color="g",
  247. horizontal_alignment="left",
  248. rotation=0,
  249. ):
  250. if not font_size:
  251. font_size = mpl.rcParams["font.size"]
  252. color = np.maximum(list(mplc.to_rgb(color)), 0.2)
  253. color[np.argmax(color)] = max(0.8, np.max(color))
  254. x, y = position
  255. ax.text(
  256. x,
  257. y,
  258. text,
  259. size=font_size,
  260. family="sans-serif",
  261. bbox={"facecolor": "none", "alpha": 0.5, "pad": 0.7, "edgecolor": "none"},
  262. verticalalignment="top",
  263. horizontalalignment=horizontal_alignment,
  264. color=color,
  265. rotation=rotation,
  266. )
  267. def draw_mask(
  268. ax, rle, color, show_holes=True, alpha=0.15, upsample_factor=1.0, rle_upsampled=None
  269. ):
  270. if isinstance(rle, dict):
  271. mask = mask_utils.decode(rle)
  272. elif isinstance(rle, np.ndarray):
  273. mask = rle
  274. else:
  275. raise ValueError(f"Unsupported type for rle: {type(rle)}")
  276. mask_upsampled = None
  277. if upsample_factor > 1.0 and show_holes:
  278. assert rle_upsampled is not None
  279. if isinstance(rle_upsampled, dict):
  280. mask_upsampled = mask_utils.decode(rle_upsampled)
  281. elif isinstance(rle_upsampled, np.ndarray):
  282. mask_upsampled = rle_upsampled
  283. else:
  284. raise ValueError(f"Unsupported type for rle: {type(rle)}")
  285. if show_holes:
  286. if mask_upsampled is None:
  287. mask_upsampled = mask
  288. h, w = mask_upsampled.shape
  289. mask_img = np.zeros((h, w, 4))
  290. mask_img[:, :, :-1] = color[np.newaxis, np.newaxis, :]
  291. mask_img[:, :, -1] = mask_upsampled * alpha
  292. ax.imshow(mask_img)
  293. *_, contours, _ = cv2.findContours(
  294. mask.astype(np.uint8).copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
  295. )
  296. upsampled_contours = [(cont + 0.5) * upsample_factor - 0.5 for cont in contours]
  297. facecolor = (0, 0, 0, 0) if show_holes else color
  298. if alpha > 0.8:
  299. edge_color = _change_color_brightness(color, brightness_factor=-0.7)
  300. else:
  301. edge_color = color
  302. for cont in upsampled_contours:
  303. polygon = mpl.patches.Polygon(
  304. [el[0] for el in cont],
  305. edgecolor=edge_color,
  306. linewidth=2.0,
  307. facecolor=facecolor,
  308. )
  309. ax.add_patch(polygon)
  310. def _change_color_brightness(color, brightness_factor):
  311. """
  312. Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
  313. less or more saturation than the original color.
  314. Args:
  315. color: color of the polygon. Refer to `matplotlib.colors` for a full list of
  316. formats that are accepted.
  317. brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
  318. 0 will correspond to no change, a factor in [-1.0, 0) range will result in
  319. a darker color and a factor in (0, 1.0] range will result in a lighter color.
  320. Returns:
  321. modified_color (tuple[double]): a tuple containing the RGB values of the
  322. modified color. Each value in the tuple is in the [0.0, 1.0] range.
  323. """
  324. assert brightness_factor >= -1.0 and brightness_factor <= 1.0
  325. color = mplc.to_rgb(color)
  326. polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
  327. modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
  328. modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
  329. modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
  330. modified_color = colorsys.hls_to_rgb(
  331. polygon_color[0], modified_lightness, polygon_color[2]
  332. )
  333. return modified_color