visualizer.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import colorsys
  4. import logging
  5. import math
  6. import random
  7. from enum import Enum, unique
  8. import cv2
  9. import matplotlib as mpl
  10. import matplotlib.colors as mplc
  11. import matplotlib.figure as mplfigure
  12. import numpy as np
  13. import pycocotools.mask as mask_util
  14. import torch
  15. from iopath.common.file_io import PathManager
  16. from matplotlib.backends.backend_agg import FigureCanvasAgg
  17. from PIL import Image
  18. from .boxes import Boxes, BoxMode
  19. from .color_map import random_color
  20. from .keypoints import Keypoints
  21. from .masks import BitMasks, PolygonMasks
  22. from .rotated_boxes import RotatedBoxes
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Names exported by `from <module> import *`.
__all__ = ["ColorMode", "VisImage", "Visualizer"]

# Area thresholds. They are not referenced in the part of the file reviewed
# here; presumably pixel areas used by label/mask drawing heuristics further
# down -- TODO confirm against their call sites.
_SMALL_OBJECT_AREA_THRESH = 1000
_LARGE_MASK_AREA_THRESH = 120000

# Color tuples with three components. `_OFF_WHITE` is passed as `edge_color`
# when binary masks are drawn; its last component (240 / 255) suggests the
# components are RGB in [0, 1] -- confirm against the drawing helpers.
_OFF_WHITE = (1.0, 1.0, 240.0 / 255)
_BLACK = (0, 0, 0)
_RED = (1.0, 0, 0)

# Copied into `Visualizer.keypoint_threshold` by the Visualizer constructor.
_KEYPOINT_THRESHOLD = 0.05
  31. @unique
  32. class ColorMode(Enum):
  33. """
  34. Enum of different color modes to use for instance visualizations.
  35. """
  36. IMAGE = 0
  37. """
  38. Picks a random color for every instance and overlay segmentations with low opacity.
  39. """
  40. SEGMENTATION = 1
  41. """
  42. Let instances of the same category have similar colors
  43. (from metadata.thing_colors), and overlay them with
  44. high opacity. This provides more attention on the quality of segmentation.
  45. """
  46. IMAGE_BW = 2
  47. """
  48. Same as IMAGE, but convert all areas without masks to gray-scale.
  49. Only available for drawing per-instance mask predictions.
  50. """
  51. class GenericMask:
  52. """
  53. Attribute:
  54. polygons (list[ndarray]): list[ndarray]: polygons for this mask.
  55. Each ndarray has format [x, y, x, y, ...]
  56. mask (ndarray): a binary mask
  57. """
  58. def __init__(self, mask_or_polygons, height, width):
  59. self._mask = self._polygons = self._has_holes = None
  60. self.height = height
  61. self.width = width
  62. m = mask_or_polygons
  63. if isinstance(m, dict):
  64. # RLEs
  65. assert "counts" in m and "size" in m
  66. if isinstance(m["counts"], list): # uncompressed RLEs
  67. h, w = m["size"]
  68. assert h == height and w == width
  69. m = mask_util.frPyObjects(m, h, w)
  70. self._mask = mask_util.decode(m)[:, :]
  71. return
  72. if isinstance(m, list): # list[ndarray]
  73. self._polygons = [np.asarray(x).reshape(-1) for x in m]
  74. return
  75. if isinstance(m, np.ndarray): # assumed to be a binary mask
  76. assert m.shape[1] != 2, m.shape
  77. assert m.shape == (
  78. height,
  79. width,
  80. ), f"mask shape: {m.shape}, target dims: {height}, {width}"
  81. self._mask = m.astype("uint8")
  82. return
  83. raise ValueError(
  84. "GenericMask cannot handle object {} of type '{}'".format(m, type(m))
  85. )
  86. @property
  87. def mask(self):
  88. if self._mask is None:
  89. self._mask = self.polygons_to_mask(self._polygons)
  90. return self._mask
  91. @property
  92. def polygons(self):
  93. if self._polygons is None:
  94. self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
  95. return self._polygons
  96. @property
  97. def has_holes(self):
  98. if self._has_holes is None:
  99. if self._mask is not None:
  100. self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
  101. else:
  102. self._has_holes = (
  103. False # if original format is polygon, does not have holes
  104. )
  105. return self._has_holes
  106. def mask_to_polygons(self, mask):
  107. # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
  108. # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
  109. # Internal contours (holes) are placed in hierarchy-2.
  110. # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
  111. mask = np.ascontiguousarray(
  112. mask
  113. ) # some versions of cv2 does not support incontiguous arr
  114. res = cv2.findContours(
  115. mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
  116. )
  117. hierarchy = res[-1]
  118. if hierarchy is None: # empty mask
  119. return [], False
  120. has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
  121. res = res[-2]
  122. res = [x.flatten() for x in res]
  123. # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
  124. # We add 0.5 to turn them into real-value coordinate space. A better solution
  125. # would be to first +0.5 and then dilate the returned polygon by 0.5.
  126. res = [x + 0.5 for x in res if len(x) >= 6]
  127. return res, has_holes
  128. def polygons_to_mask(self, polygons):
  129. rle = mask_util.frPyObjects(polygons, self.height, self.width)
  130. rle = mask_util.merge(rle)
  131. return mask_util.decode(rle)[:, :]
  132. def area(self):
  133. return self.mask.sum()
  134. def bbox(self):
  135. p = mask_util.frPyObjects(self.polygons, self.height, self.width)
  136. p = mask_util.merge(p)
  137. bbox = mask_util.toBbox(p)
  138. bbox[2] += bbox[0]
  139. bbox[3] += bbox[1]
  140. return bbox
  141. class _PanopticPrediction:
  142. """
  143. Unify different panoptic annotation/prediction formats
  144. """
  145. def __init__(self, panoptic_seg, segments_info, metadata=None):
  146. if segments_info is None:
  147. assert metadata is not None
  148. # If "segments_info" is None, we assume "panoptic_img" is a
  149. # H*W int32 image storing the panoptic_id in the format of
  150. # category_id * label_divisor + instance_id. We reserve -1 for
  151. # VOID label.
  152. label_divisor = metadata.label_divisor
  153. segments_info = []
  154. for panoptic_label in np.unique(panoptic_seg.numpy()):
  155. if panoptic_label == -1:
  156. # VOID region.
  157. continue
  158. pred_class = panoptic_label // label_divisor
  159. isthing = (
  160. pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
  161. )
  162. segments_info.append(
  163. {
  164. "id": int(panoptic_label),
  165. "category_id": int(pred_class),
  166. "isthing": bool(isthing),
  167. }
  168. )
  169. del metadata
  170. self._seg = panoptic_seg
  171. self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info
  172. segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
  173. areas = areas.numpy()
  174. sorted_idxs = np.argsort(-areas)
  175. self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
  176. self._seg_ids = self._seg_ids.tolist()
  177. for sid, area in zip(self._seg_ids, self._seg_areas):
  178. if sid in self._sinfo:
  179. self._sinfo[sid]["area"] = float(area)
  180. def non_empty_mask(self):
  181. """
  182. Returns:
  183. (H, W) array, a mask for all pixels that have a prediction
  184. """
  185. empty_ids = []
  186. for id in self._seg_ids:
  187. if id not in self._sinfo:
  188. empty_ids.append(id)
  189. if len(empty_ids) == 0:
  190. return np.zeros(self._seg.shape, dtype=np.uint8)
  191. assert len(empty_ids) == 1, (
  192. ">1 ids corresponds to no labels. This is currently not supported"
  193. )
  194. return (self._seg != empty_ids[0]).numpy().astype(np.bool)
  195. def semantic_masks(self):
  196. for sid in self._seg_ids:
  197. sinfo = self._sinfo.get(sid)
  198. if sinfo is None or sinfo["isthing"]:
  199. # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
  200. continue
  201. yield (self._seg == sid).numpy().astype(np.bool), sinfo
  202. def instance_masks(self):
  203. for sid in self._seg_ids:
  204. sinfo = self._sinfo.get(sid)
  205. if sinfo is None or not sinfo["isthing"]:
  206. continue
  207. mask = (self._seg == sid).numpy().astype(np.bool)
  208. if mask.sum() > 0:
  209. yield mask, sinfo
  210. def _create_text_labels(classes, scores, class_names, is_crowd=None):
  211. """
  212. Args:
  213. classes (list[int] or None):
  214. scores (list[float] or None):
  215. class_names (list[str] or None):
  216. is_crowd (list[bool] or None):
  217. Returns:
  218. list[str] or None
  219. """
  220. labels = None
  221. if classes is not None:
  222. if class_names is not None and len(class_names) > 0:
  223. labels = [class_names[i] for i in classes]
  224. else:
  225. labels = [str(i) for i in classes]
  226. if scores is not None:
  227. if labels is None:
  228. labels = ["{:.0f}%".format(s * 100) for s in scores]
  229. else:
  230. labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
  231. if labels is not None and is_crowd is not None:
  232. labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
  233. return labels
  234. class VisImage:
  235. def __init__(self, img, scale=1.0):
  236. """
  237. Args:
  238. img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
  239. scale (float): scale the input image
  240. """
  241. self.img = img
  242. self.scale = scale
  243. self.width, self.height = img.shape[1], img.shape[0]
  244. self._setup_figure(img)
  245. def _setup_figure(self, img):
  246. """
  247. Args:
  248. Same as in :meth:`__init__()`.
  249. Returns:
  250. fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
  251. ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
  252. """
  253. fig = mplfigure.Figure(frameon=False)
  254. self.dpi = fig.get_dpi()
  255. # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
  256. # (https://github.com/matplotlib/matplotlib/issues/15363)
  257. fig.set_size_inches(
  258. (self.width * self.scale + 1e-2) / self.dpi,
  259. (self.height * self.scale + 1e-2) / self.dpi,
  260. )
  261. self.canvas = FigureCanvasAgg(fig)
  262. # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
  263. ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
  264. ax.axis("off")
  265. self.fig = fig
  266. self.ax = ax
  267. self.reset_image(img)
  268. def reset_image(self, img):
  269. """
  270. Args:
  271. img: same as in __init__
  272. """
  273. img = img.astype("uint8")
  274. self.ax.imshow(
  275. img, extent=(0, self.width, self.height, 0), interpolation="nearest"
  276. )
  277. def save(self, filepath):
  278. """
  279. Args:
  280. filepath (str): a string that contains the absolute path, including the file name, where
  281. the visualized image will be saved.
  282. """
  283. self.fig.savefig(filepath)
  284. def get_image(self):
  285. """
  286. Returns:
  287. ndarray:
  288. the visualized image of shape (H, W, 3) (RGB) in uint8 type.
  289. The shape is scaled w.r.t the input image using the given `scale` argument.
  290. """
  291. canvas = self.canvas
  292. s, (width, height) = canvas.print_to_buffer()
  293. # buf = io.BytesIO() # works for cairo backend
  294. # canvas.print_rgba(buf)
  295. # width, height = self.width, self.height
  296. # s = buf.getvalue()
  297. buffer = np.frombuffer(s, dtype="uint8")
  298. img_rgba = buffer.reshape(height, width, 4)
  299. rgb, alpha = np.split(img_rgba, [3], axis=2)
  300. return rgb.astype("uint8")
  301. class Visualizer:
  302. """
  303. Visualizer that draws data about detection/segmentation on images.
  304. It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
  305. that draw primitive objects to images, as well as high-level wrappers like
  306. `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
  307. that draw composite data in some pre-defined style.
  308. Note that the exact visualization style for the high-level wrappers are subject to change.
  309. Style such as color, opacity, label contents, visibility of labels, or even the visibility
  310. of objects themselves (e.g. when the object is too small) may change according
  311. to different heuristics, as long as the results still look visually reasonable.
  312. To obtain a consistent style, you can implement custom drawing functions with the
  313. abovementioned primitive methods instead. If you need more customized visualization
  314. styles, you can process the data yourself following their format documented in
  315. tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
  316. intend to satisfy everyone's preference on drawing styles.
  317. This visualizer focuses on high rendering quality rather than performance. It is not
  318. designed to be used for real-time applications.
  319. """
  320. def __init__(
  321. self,
  322. img_rgb,
  323. metadata=None,
  324. scale=1.0,
  325. instance_mode=ColorMode.IMAGE,
  326. font_size_multiplier=1.3,
  327. boarder_width_multiplier=1.5,
  328. ):
  329. """
  330. Args:
  331. img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
  332. the height and width of the image respectively. C is the number of
  333. color channels. The image is required to be in RGB format since that
  334. is a requirement of the Matplotlib library. The image is also expected
  335. to be in the range [0, 255].
  336. metadata (Metadata): dataset metadata (e.g. class names and colors)
  337. instance_mode (ColorMode): defines one of the pre-defined style for drawing
  338. instances on an image.
  339. """
  340. self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
  341. self.boarder_width_multiplier = boarder_width_multiplier
  342. # if metadata is None:
  343. # metadata = MetadataCatalog.get("__nonexist__")
  344. # self.metadata = metadata
  345. self.output = VisImage(self.img, scale=scale)
  346. self.cpu_device = torch.device("cpu")
  347. # too small texts are useless, therefore clamp to 9
  348. self._default_font_size = (
  349. max(np.sqrt(self.output.height * self.output.width) // 60, 15 // scale)
  350. * font_size_multiplier
  351. )
  352. # self._default_font_size = 18
  353. self._instance_mode = instance_mode
  354. self.keypoint_threshold = _KEYPOINT_THRESHOLD
  355. import matplotlib.colors as mcolors
  356. css4_colors = mcolors.CSS4_COLORS
  357. self.color_proposals = [
  358. list(mcolors.hex2color(color)) for color in css4_colors.values()
  359. ]
  360. def draw_instance_predictions(self, predictions):
  361. """
  362. Draw instance-level prediction results on an image.
  363. Args:
  364. predictions (Instances): the output of an instance detection/segmentation
  365. model. Following fields will be used to draw:
  366. "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
  367. Returns:
  368. output (VisImage): image object with visualizations.
  369. """
  370. boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
  371. scores = predictions.scores if predictions.has("scores") else None
  372. classes = (
  373. predictions.pred_classes.tolist()
  374. if predictions.has("pred_classes")
  375. else None
  376. )
  377. labels = _create_text_labels(
  378. classes, scores, self.metadata.get("thing_classes", None)
  379. )
  380. keypoints = (
  381. predictions.pred_keypoints if predictions.has("pred_keypoints") else None
  382. )
  383. keep = (scores > 0.5).cpu()
  384. boxes = boxes[keep]
  385. scores = scores[keep]
  386. classes = np.array(classes)
  387. classes = classes[np.array(keep)]
  388. labels = np.array(labels)
  389. labels = labels[np.array(keep)]
  390. if predictions.has("pred_masks"):
  391. masks = np.asarray(predictions.pred_masks)
  392. masks = masks[np.array(keep)]
  393. masks = [
  394. GenericMask(x, self.output.height, self.output.width) for x in masks
  395. ]
  396. else:
  397. masks = None
  398. if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
  399. "thing_colors"
  400. ):
  401. # if self.metadata.get("thing_colors"):
  402. colors = [
  403. self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
  404. for c in classes
  405. ]
  406. alpha = 0.4
  407. else:
  408. colors = None
  409. alpha = 0.4
  410. if self._instance_mode == ColorMode.IMAGE_BW:
  411. self.output.reset_image(
  412. self._create_grayscale_image(
  413. (predictions.pred_masks.any(dim=0) > 0).numpy()
  414. if predictions.has("pred_masks")
  415. else None
  416. )
  417. )
  418. alpha = 0.3
  419. self.overlay_instances(
  420. masks=masks,
  421. boxes=boxes,
  422. labels=labels,
  423. keypoints=keypoints,
  424. assigned_colors=colors,
  425. alpha=alpha,
  426. )
  427. return self.output
  428. def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.7):
  429. """
  430. Draw semantic segmentation predictions/labels.
  431. Args:
  432. sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
  433. Each value is the integer label of the pixel.
  434. area_threshold (int): segments with less than `area_threshold` are not drawn.
  435. alpha (float): the larger it is, the more opaque the segmentations are.
  436. Returns:
  437. output (VisImage): image object with visualizations.
  438. """
  439. if isinstance(sem_seg, torch.Tensor):
  440. sem_seg = sem_seg.numpy()
  441. labels, areas = np.unique(sem_seg, return_counts=True)
  442. sorted_idxs = np.argsort(-areas).tolist()
  443. labels = labels[sorted_idxs]
  444. for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
  445. try:
  446. mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
  447. except (AttributeError, IndexError):
  448. mask_color = None
  449. binary_mask = (sem_seg == label).astype(np.uint8)
  450. text = self.metadata.stuff_classes[label]
  451. self.draw_binary_mask(
  452. binary_mask,
  453. color=mask_color,
  454. edge_color=_OFF_WHITE,
  455. text=text,
  456. alpha=alpha,
  457. area_threshold=area_threshold,
  458. )
  459. return self.output
  460. def draw_panoptic_seg(
  461. self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7
  462. ):
  463. """
  464. Draw panoptic prediction annotations or results.
  465. Args:
  466. panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
  467. segment.
  468. segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
  469. If it is a ``list[dict]``, each dict contains keys "id", "category_id".
  470. If None, category id of each pixel is computed by
  471. ``pixel // metadata.label_divisor``.
  472. area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
  473. Returns:
  474. output (VisImage): image object with visualizations.
  475. """
  476. pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
  477. if self._instance_mode == ColorMode.IMAGE_BW:
  478. self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
  479. # draw mask for all semantic segments first i.e. "stuff"
  480. for mask, sinfo in pred.semantic_masks():
  481. category_idx = sinfo["category_id"]
  482. try:
  483. mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
  484. except AttributeError:
  485. mask_color = None
  486. text = (
  487. self.metadata.stuff_classes[category_idx]
  488. .replace("-other", "")
  489. .replace("-merged", "")
  490. )
  491. self.draw_binary_mask(
  492. mask,
  493. color=mask_color,
  494. edge_color=_OFF_WHITE,
  495. text=text,
  496. alpha=alpha,
  497. area_threshold=area_threshold,
  498. )
  499. # draw mask for all instances second
  500. all_instances = list(pred.instance_masks())
  501. if len(all_instances) == 0:
  502. return self.output
  503. masks, sinfo = list(zip(*all_instances))
  504. category_ids = [x["category_id"] for x in sinfo]
  505. try:
  506. scores = [x["score"] for x in sinfo]
  507. except KeyError:
  508. scores = None
  509. class_names = [
  510. name.replace("-other", "").replace("-merged", "")
  511. for name in self.metadata.thing_classes
  512. ]
  513. labels = _create_text_labels(
  514. category_ids, scores, class_names, [x.get("iscrowd", 0) for x in sinfo]
  515. )
  516. try:
  517. colors = [
  518. self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
  519. for c in category_ids
  520. ]
  521. except AttributeError:
  522. colors = None
  523. self.overlay_instances(
  524. masks=masks, labels=labels, assigned_colors=colors, alpha=alpha
  525. )
  526. return self.output
  527. draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility
  528. def draw_dataset_dict(self, dic):
  529. """
  530. Draw annotations/segmentaions in Detectron2 Dataset format.
  531. Args:
  532. dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.
  533. Returns:
  534. output (VisImage): image object with visualizations.
  535. """
  536. annos = dic.get("annotations", None)
  537. if annos:
  538. if "segmentation" in annos[0]:
  539. masks = [x["segmentation"] for x in annos]
  540. else:
  541. masks = None
  542. if "keypoints" in annos[0]:
  543. keypts = [x["keypoints"] for x in annos]
  544. keypts = np.array(keypts).reshape(len(annos), -1, 3)
  545. else:
  546. keypts = None
  547. boxes = [
  548. (
  549. BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
  550. if len(x["bbox"]) == 4
  551. else x["bbox"]
  552. )
  553. for x in annos
  554. ]
  555. colors = None
  556. category_ids = [x["category_id"] for x in annos]
  557. if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
  558. "thing_colors"
  559. ):
  560. colors = [
  561. self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
  562. for c in category_ids
  563. ]
  564. names = self.metadata.get("thing_classes", None)
  565. labels = _create_text_labels(
  566. category_ids,
  567. scores=None,
  568. class_names=names,
  569. is_crowd=[x.get("iscrowd", 0) for x in annos],
  570. )
  571. self.overlay_instances(
  572. labels=labels,
  573. boxes=boxes,
  574. masks=masks,
  575. keypoints=keypts,
  576. assigned_colors=colors,
  577. )
  578. sem_seg = dic.get("sem_seg", None)
  579. if sem_seg is None and "sem_seg_file_name" in dic:
  580. with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
  581. sem_seg = Image.open(f)
  582. sem_seg = np.asarray(sem_seg, dtype="uint8")
  583. if sem_seg is not None:
  584. self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.4)
  585. pan_seg = dic.get("pan_seg", None)
  586. if pan_seg is None and "pan_seg_file_name" in dic:
  587. with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
  588. pan_seg = Image.open(f)
  589. pan_seg = np.asarray(pan_seg)
  590. from panopticapi.utils import rgb2id
  591. pan_seg = rgb2id(pan_seg)
  592. if pan_seg is not None:
  593. segments_info = dic["segments_info"]
  594. pan_seg = torch.tensor(pan_seg)
  595. self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.7)
  596. return self.output
  597. def overlay_instances(
  598. self,
  599. *,
  600. boxes=None,
  601. labels=None,
  602. masks=None,
  603. keypoints=None,
  604. assigned_colors=None,
  605. binary_masks=None,
  606. alpha=0.5,
  607. label_mode="1",
  608. ):
  609. """
  610. Args:
  611. boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
  612. or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
  613. or a :class:`RotatedBoxes`,
  614. or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
  615. for the N objects in a single image,
  616. labels (list[str]): the text to be displayed for each instance.
  617. masks (masks-like object): Supported types are:
  618. * :class:`detectron2.structures.PolygonMasks`,
  619. :class:`detectron2.structures.BitMasks`.
  620. * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
  621. The first level of the list corresponds to individual instances. The second
  622. level to all the polygon that compose the instance, and the third level
  623. to the polygon coordinates. The third level should have the format of
  624. [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
  625. * list[ndarray]: each ndarray is a binary mask of shape (H, W).
  626. * list[dict]: each dict is a COCO-style RLE.
  627. keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
  628. where the N is the number of instances and K is the number of keypoints.
  629. The last dimension corresponds to (x, y, visibility or score).
  630. assigned_colors (list[matplotlib.colors]): a list of colors, where each color
  631. corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
  632. for full list of formats that the colors are accepted in.
  633. Returns:
  634. output (VisImage): image object with visualizations.
  635. """
  636. num_instances = 0
  637. if boxes is not None:
  638. boxes = self._convert_boxes(boxes)
  639. num_instances = len(boxes)
  640. if masks is not None:
  641. masks = self._convert_masks(masks)
  642. if num_instances:
  643. assert len(masks) == num_instances
  644. else:
  645. num_instances = len(masks)
  646. if keypoints is not None:
  647. if num_instances:
  648. assert len(keypoints) == num_instances
  649. else:
  650. num_instances = len(keypoints)
  651. keypoints = self._convert_keypoints(keypoints)
  652. if labels is not None:
  653. assert len(labels) == num_instances
  654. if assigned_colors is None:
  655. assigned_colors = [
  656. random_color(rgb=True, maximum=1) for _ in range(num_instances)
  657. ]
  658. if num_instances == 0:
  659. return labels, [], []
  660. if boxes is not None and boxes.shape[1] == 5:
  661. return self.overlay_rotated_instances(
  662. boxes=boxes, labels=labels, assigned_colors=assigned_colors
  663. )
  664. # Display in largest to smallest order to reduce occlusion.
  665. areas = None
  666. if boxes is not None:
  667. areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
  668. elif masks is not None:
  669. areas = np.asarray([x.area() for x in masks])
  670. # if areas is not None:
  671. # # sorted_idxs = np.argsort(areas).tolist()
  672. # sorted_idxs = np.argsort(-areas).tolist()
  673. # # Re-order overlapped instances in descending order.
  674. # boxes = boxes[sorted_idxs] if boxes is not None else None
  675. # labels = [labels[k] for k in sorted_idxs] if labels is not None else None
  676. # masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
  677. # binary_masks = (
  678. # [binary_masks[idx] for idx in sorted_idxs]
  679. # if binary_masks is not None
  680. # else None
  681. # )
  682. # assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
  683. # keypoints = keypoints[sorted_idxs] if keypoints is not None else None
  684. marks = []
  685. marks_position = []
  686. added_positions = set()
  687. for i in range(num_instances):
  688. color = assigned_colors[i]
  689. if boxes is not None:
  690. self.draw_box(boxes[i], alpha=1, edge_color=color)
  691. if binary_masks is None:
  692. # draw number for non-mask instances
  693. mark = self._draw_number_in_box(
  694. boxes[i], i + 1, color=color, label_mode=label_mode
  695. )
  696. marks.append(mark)
  697. if binary_masks is not None:
  698. mark, mask_position = self._draw_number_in_mask(
  699. binary_mask=binary_masks[i].astype("uint8"),
  700. text=i + 1,
  701. color=color,
  702. added_positions=added_positions,
  703. label_mode=label_mode,
  704. )
  705. marks.append(mark)
  706. marks_position.append(mask_position)
  707. self.draw_binary_mask(
  708. binary_masks[i],
  709. color=color,
  710. edge_color=_OFF_WHITE,
  711. alpha=alpha,
  712. )
  713. if masks is not None:
  714. for segment in masks[i].polygons:
  715. self.draw_polygon(
  716. segment.reshape(-1, 2), color, alpha=0
  717. ) # alpha=0 so holes in masks are not colored
  718. # draw keypoints
  719. if keypoints is not None:
  720. for keypoints_per_instance in keypoints:
  721. self.draw_and_connect_keypoints(keypoints_per_instance)
  722. # return labels, marks, sorted_idxs, marks_position
  723. return labels, marks, marks_position
  724. def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
  725. """
  726. Args:
  727. boxes (ndarray): an Nx5 numpy array of
  728. (x_center, y_center, width, height, angle_degrees) format
  729. for the N objects in a single image.
  730. labels (list[str]): the text to be displayed for each instance.
  731. assigned_colors (list[matplotlib.colors]): a list of colors, where each color
  732. corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
  733. for full list of formats that the colors are accepted in.
  734. Returns:
  735. output (VisImage): image object with visualizations.
  736. """
  737. num_instances = len(boxes)
  738. if assigned_colors is None:
  739. assigned_colors = [
  740. random_color(rgb=True, maximum=1) for _ in range(num_instances)
  741. ]
  742. if num_instances == 0:
  743. return self.output
  744. # Display in largest to smallest order to reduce occlusion.
  745. if boxes is not None:
  746. areas = boxes[:, 2] * boxes[:, 3]
  747. sorted_idxs = np.argsort(-areas).tolist()
  748. # Re-order overlapped instances in descending order.
  749. boxes = boxes[sorted_idxs]
  750. labels = [labels[k] for k in sorted_idxs] if labels is not None else None
  751. colors = [assigned_colors[idx] for idx in sorted_idxs]
  752. for i in range(num_instances):
  753. self.draw_rotated_box_with_label(
  754. boxes[i],
  755. edge_color=colors[i],
  756. label=labels[i] if labels is not None else None,
  757. )
  758. return self.output
  759. def draw_and_connect_keypoints(self, keypoints):
  760. """
  761. Draws keypoints of an instance and follows the rules for keypoint connections
  762. to draw lines between appropriate keypoints. This follows color heuristics for
  763. line color.
  764. Args:
  765. keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
  766. and the last dimension corresponds to (x, y, probability).
  767. Returns:
  768. output (VisImage): image object with visualizations.
  769. """
  770. visible = {}
  771. keypoint_names = self.metadata.get("keypoint_names")
  772. for idx, keypoint in enumerate(keypoints):
  773. # draw keypoint
  774. x, y, prob = keypoint
  775. if prob > self.keypoint_threshold:
  776. self.draw_circle((x, y), color=_RED)
  777. if keypoint_names:
  778. keypoint_name = keypoint_names[idx]
  779. visible[keypoint_name] = (x, y)
  780. if self.metadata.get("keypoint_connection_rules"):
  781. for kp0, kp1, color in self.metadata.keypoint_connection_rules:
  782. if kp0 in visible and kp1 in visible:
  783. x0, y0 = visible[kp0]
  784. x1, y1 = visible[kp1]
  785. color = tuple(x / 255.0 for x in color)
  786. self.draw_line([x0, x1], [y0, y1], color=color)
  787. # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
  788. # Note that this strategy is specific to person keypoints.
  789. # For other keypoints, it should just do nothing
  790. try:
  791. ls_x, ls_y = visible["left_shoulder"]
  792. rs_x, rs_y = visible["right_shoulder"]
  793. mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
  794. except KeyError:
  795. pass
  796. else:
  797. # draw line from nose to mid-shoulder
  798. nose_x, nose_y = visible.get("nose", (None, None))
  799. if nose_x is not None:
  800. self.draw_line(
  801. [nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED
  802. )
  803. try:
  804. # draw line from mid-shoulder to mid-hip
  805. lh_x, lh_y = visible["left_hip"]
  806. rh_x, rh_y = visible["right_hip"]
  807. except KeyError:
  808. pass
  809. else:
  810. mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
  811. self.draw_line(
  812. [mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED
  813. )
  814. return self.output
  815. def mask_dims_from_binary(self, binary_mask):
  816. ind_y, ind_x = np.where(binary_mask == 1)
  817. min_ind_x = np.min(ind_x)
  818. max_ind_x = np.max(ind_x)
  819. min_ind_y = np.min(ind_y)
  820. max_ind_y = np.max(ind_y)
  821. return (max_ind_x - min_ind_x), (max_ind_y - min_ind_y)
  822. def reposition_label(self, position, cur, binary_mask, move_count):
  823. img_width, img_height = self.output.width, self.output.height
  824. mask_width, mask_height = self.mask_dims_from_binary(binary_mask)
  825. # set resposition thresholds
  826. mask_width_limit, mask_height_limit = (
  827. 25,
  828. 25,
  829. ) # limit for width and height size for object covering
  830. location_diff_threshold = 15 # limit for the distance between two labels
  831. x_boundry_limit, y_boundry_limit = (
  832. 20,
  833. 20,
  834. ) # limit for the distancing the label from edges
  835. offset_x = 15 # move in x direction
  836. offset_y = 15 # move in y direction
  837. x1, y1 = position
  838. if (
  839. mask_width < mask_width_limit
  840. and mask_height < mask_height_limit
  841. and move_count == 0
  842. ):
  843. move_x = offset_x if offset_x + x1 < img_width else -offset_x
  844. move_y = offset_y if offset_y + y1 < img_height else -offset_y
  845. return (True, move_x, move_y)
  846. for x2, y2 in cur:
  847. if abs(x1 - x2) + abs(y1 - y2) < location_diff_threshold:
  848. move_x = offset_x if x1 >= x2 else -offset_x
  849. move_y = offset_y if y1 >= y2 else -offset_y
  850. move_x = (
  851. 0
  852. if x1 + move_x > img_width - x_boundry_limit
  853. or x1 + move_x < x_boundry_limit
  854. else move_x
  855. )
  856. move_y = (
  857. 0
  858. if y1 + move_y > img_height - y_boundry_limit
  859. or y1 + move_y < y_boundry_limit
  860. else move_y
  861. )
  862. return (
  863. True,
  864. move_x,
  865. move_y,
  866. )
  867. return (False, 0, 0)
  868. def locate_label_position(self, original_position, added_positions, binary_mask):
  869. if added_positions is None or binary_mask is None:
  870. return original_position
  871. x, y = original_position
  872. move_count = 0
  873. reposition, x_move, y_move = self.reposition_label(
  874. (x, y), added_positions, binary_mask, move_count
  875. )
  876. while reposition and move_count < 10:
  877. x += x_move
  878. y += y_move
  879. move_count += 1
  880. reposition, x_move, y_move = self.reposition_label(
  881. (x, y), added_positions, binary_mask, move_count
  882. )
  883. added_positions.add((x, y))
  884. return x, y
  885. """
  886. Primitive drawing functions:
  887. """
  888. def draw_text(
  889. self,
  890. text,
  891. position,
  892. added_positions=None,
  893. binary_mask=None,
  894. *,
  895. font_size=None,
  896. color="g",
  897. horizontal_alignment="center",
  898. rotation=0,
  899. ):
  900. """
  901. Args:
  902. text (str): class label
  903. position (tuple): a tuple of the x and y coordinates to place text on image.
  904. font_size (int, optional): font of the text. If not provided, a font size
  905. proportional to the image width is calculated and used.
  906. color: color of the text. Refer to `matplotlib.colors` for full list
  907. of formats that are accepted.
  908. horizontal_alignment (str): see `matplotlib.text.Text`
  909. rotation: rotation angle in degrees CCW
  910. Returns:
  911. output (VisImage): image object with text drawn.
  912. """
  913. if not font_size:
  914. font_size = self._default_font_size
  915. # since the text background is dark, we don't want the text to be dark
  916. color = np.maximum(list(mplc.to_rgb(color)), 0.15)
  917. color[np.argmax(color)] = max(0.8, np.max(color))
  918. def contrasting_color(rgb):
  919. """Returns 'white' or 'black' depending on which color contrasts more with the given RGB value."""
  920. # Decompose the RGB tuple
  921. R, G, B = rgb
  922. # Calculate the Y value
  923. Y = 0.299 * R + 0.587 * G + 0.114 * B
  924. # If Y value is greater than 128, it's closer to white so return black. Otherwise, return white.
  925. return "black" if Y > 128 else "white"
  926. bbox_background = contrasting_color(color * 255)
  927. x, y = self.locate_label_position(
  928. original_position=position,
  929. added_positions=added_positions,
  930. binary_mask=binary_mask,
  931. )
  932. self.output.ax.text(
  933. x,
  934. y,
  935. text,
  936. size=font_size * self.output.scale,
  937. family="sans-serif",
  938. bbox={
  939. "facecolor": bbox_background,
  940. "alpha": 0.8,
  941. "pad": 0.7,
  942. "edgecolor": "none",
  943. },
  944. verticalalignment="top",
  945. horizontalalignment=horizontal_alignment,
  946. color=color,
  947. zorder=10,
  948. rotation=rotation,
  949. )
  950. return self.output
  951. def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
  952. """
  953. Args:
  954. box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
  955. are the coordinates of the image's top left corner. x1 and y1 are the
  956. coordinates of the image's bottom right corner.
  957. alpha (float): blending efficient. Smaller values lead to more transparent masks.
  958. edge_color: color of the outline of the box. Refer to `matplotlib.colors`
  959. for full list of formats that are accepted.
  960. line_style (string): the string to use to create the outline of the boxes.
  961. Returns:
  962. output (VisImage): image object with box drawn.
  963. """
  964. x0, y0, x1, y1 = box_coord
  965. width = x1 - x0
  966. height = y1 - y0
  967. linewidth = max(self._default_font_size / 12, 1) * self.boarder_width_multiplier
  968. self.output.ax.add_patch(
  969. mpl.patches.Rectangle(
  970. (x0, y0),
  971. width,
  972. height,
  973. fill=False,
  974. edgecolor=edge_color,
  975. linewidth=linewidth * self.output.scale,
  976. alpha=alpha,
  977. linestyle=line_style,
  978. )
  979. )
  980. return self.output
  981. def draw_rotated_box_with_label(
  982. self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
  983. ):
  984. """
  985. Draw a rotated box with label on its top-left corner.
  986. Args:
  987. rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
  988. where cnt_x and cnt_y are the center coordinates of the box.
  989. w and h are the width and height of the box. angle represents how
  990. many degrees the box is rotated CCW with regard to the 0-degree box.
  991. alpha (float): blending efficient. Smaller values lead to more transparent masks.
  992. edge_color: color of the outline of the box. Refer to `matplotlib.colors`
  993. for full list of formats that are accepted.
  994. line_style (string): the string to use to create the outline of the boxes.
  995. label (string): label for rotated box. It will not be rendered when set to None.
  996. Returns:
  997. output (VisImage): image object with box drawn.
  998. """
  999. cnt_x, cnt_y, w, h, angle = rotated_box
  1000. area = w * h
  1001. # use thinner lines when the box is small
  1002. linewidth = self._default_font_size / (
  1003. 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
  1004. )
  1005. theta = angle * math.pi / 180.0
  1006. c = math.cos(theta)
  1007. s = math.sin(theta)
  1008. rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
  1009. # x: left->right ; y: top->down
  1010. rotated_rect = [
  1011. (s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect
  1012. ]
  1013. for k in range(4):
  1014. j = (k + 1) % 4
  1015. self.draw_line(
  1016. [rotated_rect[k][0], rotated_rect[j][0]],
  1017. [rotated_rect[k][1], rotated_rect[j][1]],
  1018. color=edge_color,
  1019. linestyle="--" if k == 1 else line_style,
  1020. linewidth=linewidth,
  1021. )
  1022. if label is not None:
  1023. text_pos = rotated_rect[1] # topleft corner
  1024. height_ratio = h / np.sqrt(self.output.height * self.output.width)
  1025. label_color = self._change_color_brightness(
  1026. edge_color, brightness_factor=0.7
  1027. )
  1028. font_size = (
  1029. np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
  1030. * 0.5
  1031. * self._default_font_size
  1032. )
  1033. self.draw_text(
  1034. label, text_pos, color=label_color, font_size=font_size, rotation=angle
  1035. )
  1036. return self.output
  1037. def draw_circle(self, circle_coord, color, radius=3):
  1038. """
  1039. Args:
  1040. circle_coord (list(int) or tuple(int)): contains the x and y coordinates
  1041. of the center of the circle.
  1042. color: color of the polygon. Refer to `matplotlib.colors` for a full list of
  1043. formats that are accepted.
  1044. radius (int): radius of the circle.
  1045. Returns:
  1046. output (VisImage): image object with box drawn.
  1047. """
  1048. x, y = circle_coord
  1049. self.output.ax.add_patch(
  1050. mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
  1051. )
  1052. return self.output
  1053. def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
  1054. """
  1055. Args:
  1056. x_data (list[int]): a list containing x values of all the points being drawn.
  1057. Length of list should match the length of y_data.
  1058. y_data (list[int]): a list containing y values of all the points being drawn.
  1059. Length of list should match the length of x_data.
  1060. color: color of the line. Refer to `matplotlib.colors` for a full list of
  1061. formats that are accepted.
  1062. linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
  1063. for a full list of formats that are accepted.
  1064. linewidth (float or None): width of the line. When it's None,
  1065. a default value will be computed and used.
  1066. Returns:
  1067. output (VisImage): image object with line drawn.
  1068. """
  1069. if linewidth is None:
  1070. linewidth = self._default_font_size / 3
  1071. linewidth = max(linewidth, 1)
  1072. self.output.ax.add_line(
  1073. mpl.lines.Line2D(
  1074. x_data,
  1075. y_data,
  1076. linewidth=linewidth * self.output.scale,
  1077. color=color,
  1078. linestyle=linestyle,
  1079. )
  1080. )
  1081. return self.output
  1082. def draw_binary_mask(
  1083. self,
  1084. binary_mask,
  1085. color=None,
  1086. *,
  1087. edge_color=None,
  1088. text=None,
  1089. alpha=0.7,
  1090. area_threshold=10,
  1091. ):
  1092. """
  1093. Args:
  1094. binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
  1095. W is the image width. Each value in the array is either a 0 or 1 value of uint8
  1096. type.
  1097. color: color of the mask. Refer to `matplotlib.colors` for a full list of
  1098. formats that are accepted. If None, will pick a random color.
  1099. edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
  1100. full list of formats that are accepted.
  1101. text (str): if None, will be drawn on the object
  1102. alpha (float): blending efficient. Smaller values lead to more transparent masks.
  1103. area_threshold (float): a connected component smaller than this area will not be shown.
  1104. Returns:
  1105. output (VisImage): image object with mask drawn.
  1106. """
  1107. if color is None:
  1108. color = random_color(rgb=True, maximum=1)
  1109. color = mplc.to_rgb(color)
  1110. has_valid_segment = False
  1111. binary_mask = binary_mask.astype("uint8") # opencv needs uint8
  1112. mask = GenericMask(binary_mask, self.output.height, self.output.width)
  1113. shape2d = (binary_mask.shape[0], binary_mask.shape[1])
  1114. if not mask.has_holes:
  1115. # draw polygons for regular masks
  1116. for segment in mask.polygons:
  1117. area = mask_util.area(
  1118. mask_util.frPyObjects([segment], shape2d[0], shape2d[1])
  1119. )
  1120. if area < (area_threshold or 0):
  1121. continue
  1122. has_valid_segment = True
  1123. segment = segment.reshape(-1, 2)
  1124. self.draw_polygon(
  1125. segment, color=color, edge_color=edge_color, alpha=alpha
  1126. )
  1127. else:
  1128. # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
  1129. rgba = np.zeros(shape2d + (4,), dtype="float32")
  1130. rgba[:, :, :3] = color
  1131. rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
  1132. has_valid_segment = True
  1133. self.output.ax.imshow(
  1134. rgba, extent=(0, self.output.width, self.output.height, 0)
  1135. )
  1136. if text is not None and has_valid_segment:
  1137. lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
  1138. self._draw_text_in_mask(binary_mask, text, lighter_color)
  1139. return self.output
  1140. def draw_binary_mask_with_number(
  1141. self,
  1142. binary_mask,
  1143. color=None,
  1144. *,
  1145. edge_color=None,
  1146. text=None,
  1147. label_mode="1",
  1148. alpha=0.1,
  1149. anno_mode=["Mask"],
  1150. area_threshold=10,
  1151. ):
  1152. """
  1153. Args:
  1154. binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
  1155. W is the image width. Each value in the array is either a 0 or 1 value of uint8
  1156. type.
  1157. color: color of the mask. Refer to `matplotlib.colors` for a full list of
  1158. formats that are accepted. If None, will pick a random color.
  1159. edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
  1160. full list of formats that are accepted.
  1161. text (str): if None, will be drawn on the object
  1162. alpha (float): blending efficient. Smaller values lead to more transparent masks.
  1163. area_threshold (float): a connected component smaller than this area will not be shown.
  1164. Returns:
  1165. output (VisImage): image object with mask drawn.
  1166. """
  1167. if color is None:
  1168. randint = random.randint(0, len(self.color_proposals) - 1)
  1169. color = self.color_proposals[randint]
  1170. color = mplc.to_rgb(color)
  1171. has_valid_segment = True
  1172. binary_mask = binary_mask.astype("uint8") # opencv needs uint8
  1173. mask = GenericMask(binary_mask, self.output.height, self.output.width)
  1174. shape2d = (binary_mask.shape[0], binary_mask.shape[1])
  1175. bbox = mask.bbox()
  1176. if "Mask" in anno_mode:
  1177. if not mask.has_holes:
  1178. # draw polygons for regular masks
  1179. for segment in mask.polygons:
  1180. area = mask_util.area(
  1181. mask_util.frPyObjects([segment], shape2d[0], shape2d[1])
  1182. )
  1183. if area < (area_threshold or 0):
  1184. continue
  1185. has_valid_segment = True
  1186. segment = segment.reshape(-1, 2)
  1187. self.draw_polygon(
  1188. segment, color=color, edge_color=edge_color, alpha=alpha
  1189. )
  1190. else:
  1191. # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
  1192. rgba = np.zeros(shape2d + (4,), dtype="float32")
  1193. rgba[:, :, :3] = color
  1194. rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
  1195. has_valid_segment = True
  1196. self.output.ax.imshow(
  1197. rgba, extent=(0, self.output.width, self.output.height, 0)
  1198. )
  1199. if "Box" in anno_mode:
  1200. self.draw_box(bbox, edge_color=color, alpha=0.75)
  1201. if "Mark" in anno_mode:
  1202. has_valid_segment = True
  1203. else:
  1204. has_valid_segment = False
  1205. if text is not None and has_valid_segment:
  1206. # lighter_color = tuple([x*0.2 for x in color])
  1207. lighter_color = [
  1208. 1,
  1209. 1,
  1210. 1,
  1211. ] # self._change_color_brightness(color, brightness_factor=0.7)
  1212. self._draw_number_in_mask(
  1213. binary_mask=binary_mask,
  1214. text=text,
  1215. color=lighter_color,
  1216. label_mode=label_mode,
  1217. )
  1218. return self.output
  1219. def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
  1220. """
  1221. Args:
  1222. soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
  1223. color: color of the mask. Refer to `matplotlib.colors` for a full list of
  1224. formats that are accepted. If None, will pick a random color.
  1225. text (str): if None, will be drawn on the object
  1226. alpha (float): blending efficient. Smaller values lead to more transparent masks.
  1227. Returns:
  1228. output (VisImage): image object with mask drawn.
  1229. """
  1230. if color is None:
  1231. color = random_color(rgb=True, maximum=1)
  1232. color = mplc.to_rgb(color)
  1233. shape2d = (soft_mask.shape[0], soft_mask.shape[1])
  1234. rgba = np.zeros(shape2d + (4,), dtype="float32")
  1235. rgba[:, :, :3] = color
  1236. rgba[:, :, 3] = soft_mask * alpha
  1237. self.output.ax.imshow(
  1238. rgba, extent=(0, self.output.width, self.output.height, 0)
  1239. )
  1240. if text is not None:
  1241. lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
  1242. binary_mask = (soft_mask > 0.5).astype("uint8")
  1243. self._draw_text_in_mask(binary_mask, text, lighter_color)
  1244. return self.output
  1245. def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
  1246. """
  1247. Args:
  1248. segment: numpy array of shape Nx2, containing all the points in the polygon.
  1249. color: color of the polygon. Refer to `matplotlib.colors` for a full list of
  1250. formats that are accepted.
  1251. edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
  1252. full list of formats that are accepted. If not provided, a darker shade
  1253. of the polygon color will be used instead.
  1254. alpha (float): blending efficient. Smaller values lead to more transparent masks.
  1255. Returns:
  1256. output (VisImage): image object with polygon drawn.
  1257. """
  1258. if edge_color is None:
  1259. # make edge color darker than the polygon color
  1260. if alpha > 0.8:
  1261. edge_color = self._change_color_brightness(
  1262. color, brightness_factor=-0.7
  1263. )
  1264. else:
  1265. edge_color = color
  1266. edge_color = mplc.to_rgb(edge_color) + (1,)
  1267. polygon = mpl.patches.Polygon(
  1268. segment,
  1269. fill=True,
  1270. facecolor=mplc.to_rgb(color) + (alpha,),
  1271. edgecolor=edge_color,
  1272. linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
  1273. )
  1274. self.output.ax.add_patch(polygon)
  1275. return self.output
  1276. """
  1277. Internal methods:
  1278. """
  1279. def _jitter(self, color):
  1280. """
  1281. Randomly modifies given color to produce a slightly different color than the color given.
  1282. Args:
  1283. color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
  1284. picked. The values in the list are in the [0.0, 1.0] range.
  1285. Returns:
  1286. jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
  1287. color after being jittered. The values in the list are in the [0.0, 1.0] range.
  1288. """
  1289. color = mplc.to_rgb(color)
  1290. # np.random.seed(0)
  1291. vec = np.random.rand(3)
  1292. # better to do it in another color space
  1293. vec = vec / np.linalg.norm(vec) * 0.5
  1294. res = np.clip(vec + color, 0, 1)
  1295. return tuple(res)
  1296. def _create_grayscale_image(self, mask=None):
  1297. """
  1298. Create a grayscale version of the original image.
  1299. The colors in masked area, if given, will be kept.
  1300. """
  1301. img_bw = self.img.astype("f4").mean(axis=2)
  1302. img_bw = np.stack([img_bw] * 3, axis=2)
  1303. if mask is not None:
  1304. img_bw[mask] = self.img[mask]
  1305. return img_bw
  1306. def _change_color_brightness(self, color, brightness_factor):
  1307. """
  1308. Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
  1309. less or more saturation than the original color.
  1310. Args:
  1311. color: color of the polygon. Refer to `matplotlib.colors` for a full list of
  1312. formats that are accepted.
  1313. brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
  1314. 0 will correspond to no change, a factor in [-1.0, 0) range will result in
  1315. a darker color and a factor in (0, 1.0] range will result in a lighter color.
  1316. Returns:
  1317. modified_color (tuple[double]): a tuple containing the RGB values of the
  1318. modified color. Each value in the tuple is in the [0.0, 1.0] range.
  1319. """
  1320. assert brightness_factor >= -1.0 and brightness_factor <= 1.0
  1321. color = mplc.to_rgb(color)
  1322. polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
  1323. modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
  1324. modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
  1325. modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
  1326. modified_color = colorsys.hls_to_rgb(
  1327. polygon_color[0], modified_lightness, polygon_color[2]
  1328. )
  1329. return modified_color
  1330. def _convert_boxes(self, boxes):
  1331. """
  1332. Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.
  1333. """
  1334. if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
  1335. return boxes.tensor.detach().numpy()
  1336. else:
  1337. return np.asarray(boxes)
  1338. def _convert_masks(self, masks_or_polygons):
  1339. """
  1340. Convert different format of masks or polygons to a tuple of masks and polygons.
  1341. Returns:
  1342. list[GenericMask]:
  1343. """
  1344. m = masks_or_polygons
  1345. if isinstance(m, PolygonMasks):
  1346. m = m.polygons
  1347. if isinstance(m, BitMasks):
  1348. m = m.tensor.numpy()
  1349. if isinstance(m, torch.Tensor):
  1350. m = m.numpy()
  1351. ret = []
  1352. for x in m:
  1353. if isinstance(x, GenericMask):
  1354. ret.append(x)
  1355. else:
  1356. ret.append(GenericMask(x, self.output.height, self.output.width))
  1357. return ret
  1358. def _draw_number_in_box(self, box, text, color, label_mode="1"):
  1359. """
  1360. Find proper places to draw text given a box.
  1361. """
  1362. x0, y0, x1, y1 = box
  1363. text_pos = (x0, y0) # if drawing boxes, put text on the box corner.
  1364. horiz_align = "left"
  1365. # for small objects, draw text at the side to avoid occlusion
  1366. instance_area = (y1 - y0) * (x1 - x0)
  1367. if (
  1368. instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
  1369. or y1 - y0 < 40 * self.output.scale
  1370. ):
  1371. if y1 >= self.output.height - 5:
  1372. text_pos = (x1, y0)
  1373. else:
  1374. text_pos = (x0, y1)
  1375. height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
  1376. lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
  1377. font_size = (
  1378. np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
  1379. * 0.65
  1380. * self._default_font_size
  1381. )
  1382. if label_mode == "a":
  1383. text = self.number_to_string(int(text))
  1384. else:
  1385. text = text
  1386. self.draw_text(
  1387. text,
  1388. text_pos,
  1389. color=lighter_color,
  1390. horizontal_alignment=horiz_align,
  1391. font_size=font_size,
  1392. )
  1393. return str(text)
  1394. @staticmethod
  1395. def number_to_string(n):
  1396. chars = []
  1397. while n:
  1398. n, remainder = divmod(n - 1, 26)
  1399. chars.append(chr(97 + remainder))
  1400. return "".join(reversed(chars))
  1401. def _draw_number_in_mask(
  1402. self, binary_mask, text, color, added_positions=None, label_mode="1"
  1403. ):
  1404. """
  1405. Find proper places to draw text given a binary mask.
  1406. """
  1407. binary_mask = np.pad(binary_mask, ((1, 1), (1, 1)), "constant")
  1408. mask_dt = cv2.distanceTransform(binary_mask, cv2.DIST_L2, 0)
  1409. mask_dt = mask_dt[1:-1, 1:-1]
  1410. max_dist = np.max(mask_dt)
  1411. coords_y, coords_x = np.where(mask_dt == max_dist) # coords is [y, x]
  1412. if label_mode == "a":
  1413. text = self.number_to_string(int(text))
  1414. else:
  1415. text = text
  1416. text_position = (
  1417. coords_x[len(coords_x) // 2] + 2,
  1418. coords_y[len(coords_y) // 2] - 6,
  1419. )
  1420. self.draw_text(
  1421. text,
  1422. text_position,
  1423. added_positions=added_positions,
  1424. binary_mask=binary_mask,
  1425. color=color,
  1426. )
  1427. return str(text), text_position
  1428. # _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
  1429. # if stats[1:, -1].size == 0:
  1430. # return
  1431. # largest_component_id = np.argmax(stats[1:, -1]) + 1
  1432. # # draw text on the largest component, as well as other very large components.
  1433. # for cid in range(1, _num_cc):
  1434. # if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
  1435. # # median is more stable than centroid
  1436. # # center = centroids[largest_component_id]
  1437. # center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
  1438. # # bottom=np.max((cc_labels == cid).nonzero(), axis=1)[::-1]
  1439. # # center[1]=bottom[1]+2
  1440. # self.draw_text(text, center, color=color)
  1441. def _draw_text_in_mask(self, binary_mask, text, color):
  1442. """
  1443. Find proper places to draw text given a binary mask.
  1444. """
  1445. _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(
  1446. binary_mask, 8
  1447. )
  1448. if stats[1:, -1].size == 0:
  1449. return
  1450. largest_component_id = np.argmax(stats[1:, -1]) + 1
  1451. # draw text on the largest component, as well as other very large components.
  1452. for cid in range(1, _num_cc):
  1453. if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
  1454. # median is more stable than centroid
  1455. # center = centroids[largest_component_id]
  1456. center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
  1457. bottom = np.max((cc_labels == cid).nonzero(), axis=1)[::-1]
  1458. center[1] = bottom[1] + 2
  1459. self.draw_text(text, center, color=color)
  1460. def _convert_keypoints(self, keypoints):
  1461. if isinstance(keypoints, Keypoints):
  1462. keypoints = keypoints.tensor
  1463. keypoints = np.asarray(keypoints)
  1464. return keypoints
  1465. def get_output(self):
  1466. """
  1467. Returns:
  1468. output (VisImage): the image output containing the visualizations added
  1469. to the image.
  1470. """
  1471. return self.output