  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import io
  4. import math
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7. import pycocotools.mask as mask_utils
  8. from PIL import Image
  9. from .som_utils import ColorPalette, draw_box, draw_mask, draw_text
  10. def render_zoom_in(
  11. object_data,
  12. image_file,
  13. show_box: bool = True,
  14. show_text: bool = False,
  15. show_holes: bool = True,
  16. mask_alpha: float = 0.15,
  17. ):
  18. """
  19. Render a two-panel visualization with a cropped original view (left/upper) and a zoomed-in
  20. mask overlay (right/lower), then return it as a PIL.Image along with the chosen mask color (hex).
  21. Parameters
  22. ----------
  23. object_data : dict
  24. Dict containing "labels" and COCO RLE "segmentation".
  25. Expected:
  26. object_data["labels"][0]["noun_phrase"] : str
  27. object_data["segmentation"] : COCO RLE (with "size": [H, W])
  28. image_file : PIL.Image.Image
  29. Source image (PIL).
  30. show_box : bool
  31. Whether to draw the bbox on the cropped original panel.
  32. show_text : bool
  33. Whether to draw the noun phrase label near the bbox.
  34. show_holes : bool
  35. Whether to render mask holes (passed through to draw_mask).
  36. mask_alpha : float
  37. Alpha for the mask overlay.
  38. Returns
  39. -------
  40. pil_img : PIL.Image.Image
  41. The composed visualization image.
  42. color_hex : str
  43. Hex string of the chosen mask color.
  44. """
  45. # ---- local constants (avoid module-level globals) ----
  46. _AREA_LARGE = 0.25
  47. _AREA_MEDIUM = 0.05
  48. # ---- local helpers (avoid name collisions in a larger class) ----
  49. def _get_shift(x, w, w_new, w_img):
  50. assert 0 <= w_new <= w_img
  51. shift = (w_new - w) / 2
  52. if x - shift + w_new > w_img:
  53. shift = x + w_new - w_img
  54. return min(x, shift)
  55. def _get_zoom_in_box(mask_box_xywh, img_h, img_w, mask_area):
  56. box_w, box_h = mask_box_xywh[2], mask_box_xywh[3]
  57. w_new = min(box_w + max(0.2 * box_w, 16), img_w)
  58. h_new = min(box_h + max(0.2 * box_h, 16), img_h)
  59. mask_relative_area = mask_area / (w_new * h_new)
  60. # zoom-in (larger box if mask is relatively big)
  61. w_new_large, h_new_large = w_new, h_new
  62. if mask_relative_area > _AREA_LARGE:
  63. ratio_large = math.sqrt(mask_relative_area / _AREA_LARGE)
  64. w_new_large = min(w_new * ratio_large, img_w)
  65. h_new_large = min(h_new * ratio_large, img_h)
  66. w_shift_large = _get_shift(
  67. mask_box_xywh[0], mask_box_xywh[2], w_new_large, img_w
  68. )
  69. h_shift_large = _get_shift(
  70. mask_box_xywh[1], mask_box_xywh[3], h_new_large, img_h
  71. )
  72. zoom_in_box = [
  73. mask_box_xywh[0] - w_shift_large,
  74. mask_box_xywh[1] - h_shift_large,
  75. w_new_large,
  76. h_new_large,
  77. ]
  78. # crop box for the original/cropped image
  79. w_new_medium, h_new_medium = w_new, h_new
  80. if mask_relative_area > _AREA_MEDIUM:
  81. ratio_med = math.sqrt(mask_relative_area / _AREA_MEDIUM)
  82. w_new_medium = min(w_new * ratio_med, img_w)
  83. h_new_medium = min(h_new * ratio_med, img_h)
  84. w_shift_medium = _get_shift(
  85. mask_box_xywh[0], mask_box_xywh[2], w_new_medium, img_w
  86. )
  87. h_shift_medium = _get_shift(
  88. mask_box_xywh[1], mask_box_xywh[3], h_new_medium, img_h
  89. )
  90. img_crop_box = [
  91. mask_box_xywh[0] - w_shift_medium,
  92. mask_box_xywh[1] - h_shift_medium,
  93. w_new_medium,
  94. h_new_medium,
  95. ]
  96. return zoom_in_box, img_crop_box
  97. # ---- main body ----
  98. # Input parsing
  99. object_label = object_data["labels"][0]["noun_phrase"]
  100. img = image_file.convert("RGB")
  101. bbox_xywh = mask_utils.toBbox(object_data["segmentation"]) # [x, y, w, h]
  102. # Choose a stable, visually distant color based on crop
  103. bbox_xyxy = [
  104. bbox_xywh[0],
  105. bbox_xywh[1],
  106. bbox_xywh[0] + bbox_xywh[2],
  107. bbox_xywh[1] + bbox_xywh[3],
  108. ]
  109. crop_img = img.crop(bbox_xyxy)
  110. color_palette = ColorPalette.default()
  111. color_obj, _ = color_palette.find_farthest_color(np.array(crop_img))
  112. color = np.array([color_obj.r / 255, color_obj.g / 255, color_obj.b / 255])
  113. color_hex = f"#{color_obj.r:02x}{color_obj.g:02x}{color_obj.b:02x}"
  114. # Compute zoom-in / crop boxes
  115. img_h, img_w = object_data["segmentation"]["size"]
  116. mask_area = mask_utils.area(object_data["segmentation"])
  117. zoom_in_box, img_crop_box = _get_zoom_in_box(bbox_xywh, img_h, img_w, mask_area)
  118. # Layout choice
  119. w, h = img_crop_box[2], img_crop_box[3]
  120. if w < h:
  121. fig, (ax1, ax2) = plt.subplots(1, 2)
  122. else:
  123. fig, (ax1, ax2) = plt.subplots(2, 1)
  124. # Panel 1: cropped original with optional box/text
  125. img_crop_box_xyxy = [
  126. img_crop_box[0],
  127. img_crop_box[1],
  128. img_crop_box[0] + img_crop_box[2],
  129. img_crop_box[1] + img_crop_box[3],
  130. ]
  131. img1 = img.crop(img_crop_box_xyxy)
  132. bbox_xywh_rel = [
  133. bbox_xywh[0] - img_crop_box[0],
  134. bbox_xywh[1] - img_crop_box[1],
  135. bbox_xywh[2],
  136. bbox_xywh[3],
  137. ]
  138. ax1.imshow(img1)
  139. ax1.axis("off")
  140. if show_box:
  141. draw_box(ax1, bbox_xywh_rel, edge_color=color)
  142. if show_text:
  143. x0, y0 = bbox_xywh_rel[0] + 2, bbox_xywh_rel[1] + 2
  144. draw_text(ax1, object_label, [x0, y0], color=color)
  145. # Panel 2: zoomed-in mask overlay
  146. binary_mask = mask_utils.decode(object_data["segmentation"])
  147. alpha = Image.fromarray((binary_mask * 255).astype("uint8"))
  148. img_rgba = img.convert("RGBA")
  149. img_rgba.putalpha(alpha)
  150. zoom_in_box_xyxy = [
  151. zoom_in_box[0],
  152. zoom_in_box[1],
  153. zoom_in_box[0] + zoom_in_box[2],
  154. zoom_in_box[1] + zoom_in_box[3],
  155. ]
  156. img_with_alpha_zoomin = img_rgba.crop(zoom_in_box_xyxy)
  157. alpha_zoomin = img_with_alpha_zoomin.split()[3]
  158. binary_mask_zoomin = np.array(alpha_zoomin).astype(bool)
  159. ax2.imshow(img_with_alpha_zoomin.convert("RGB"))
  160. ax2.axis("off")
  161. draw_mask(
  162. ax2, binary_mask_zoomin, color=color, show_holes=show_holes, alpha=mask_alpha
  163. )
  164. plt.tight_layout()
  165. # Buffer -> PIL.Image
  166. buf = io.BytesIO()
  167. fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=100)
  168. plt.close(fig)
  169. buf.seek(0)
  170. pil_img = Image.open(buf)
  171. return pil_img, color_hex