viz.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import cv2
  4. import numpy as np
  5. import pycocotools.mask as mask_utils
  6. from PIL import Image
  7. from .helpers.visualizer import Visualizer
  8. from .helpers.zoom_in import render_zoom_in
  9. def visualize(
  10. input_json: dict,
  11. zoom_in_index: int | None = None,
  12. mask_alpha: float = 0.15,
  13. label_mode: str = "1",
  14. font_size_multiplier: float = 1.2,
  15. boarder_width_multiplier: float = 0,
  16. ):
  17. """
  18. Unified visualization function.
  19. If zoom_in_index is None:
  20. - Render all masks in input_json (equivalent to visualize_masks_from_result_json).
  21. - Returns: PIL.Image
  22. If zoom_in_index is provided:
  23. - Returns two PIL.Images:
  24. 1) Output identical to zoom_in_and_visualize(input_json, index).
  25. 2) The same instance rendered via the general overlay using the color
  26. returned by (1), equivalent to calling visualize_masks_from_result_json
  27. on a single-mask json_i with color=color_hex.
  28. """
  29. # Common fields
  30. orig_h = int(input_json["orig_img_h"])
  31. orig_w = int(input_json["orig_img_w"])
  32. img_path = input_json["original_image_path"]
  33. # ---------- Mode A: Full-scene render ----------
  34. if zoom_in_index is None:
  35. boxes = np.array(input_json["pred_boxes"])
  36. rle_masks = [
  37. {"size": (orig_h, orig_w), "counts": rle}
  38. for rle in input_json["pred_masks"]
  39. ]
  40. binary_masks = [mask_utils.decode(rle) for rle in rle_masks]
  41. img_bgr = cv2.imread(img_path)
  42. if img_bgr is None:
  43. raise FileNotFoundError(f"Could not read image: {img_path}")
  44. img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
  45. viz = Visualizer(
  46. img_rgb,
  47. font_size_multiplier=font_size_multiplier,
  48. boarder_width_multiplier=boarder_width_multiplier,
  49. )
  50. viz.overlay_instances(
  51. boxes=boxes,
  52. masks=rle_masks,
  53. binary_masks=binary_masks,
  54. assigned_colors=None,
  55. alpha=mask_alpha,
  56. label_mode=label_mode,
  57. )
  58. pil_all_masks = Image.fromarray(viz.output.get_image())
  59. return pil_all_masks
  60. # ---------- Mode B: Zoom-in pair ----------
  61. else:
  62. idx = int(zoom_in_index)
  63. num_masks = len(input_json.get("pred_masks", []))
  64. if idx < 0 or idx >= num_masks:
  65. raise ValueError(
  66. f"zoom_in_index {idx} is out of range (0..{num_masks - 1})."
  67. )
  68. # (1) Replicate zoom_in_and_visualize
  69. object_data = {
  70. "labels": [{"noun_phrase": f"mask_{idx}"}],
  71. "segmentation": {
  72. "counts": input_json["pred_masks"][idx],
  73. "size": [orig_h, orig_w],
  74. },
  75. }
  76. pil_img = Image.open(img_path)
  77. pil_mask_i_zoomed, color_hex = render_zoom_in(
  78. object_data, pil_img, mask_alpha=mask_alpha
  79. )
  80. # (2) Single-instance render with the same color
  81. boxes_i = np.array([input_json["pred_boxes"][idx]])
  82. rle_i = {"size": (orig_h, orig_w), "counts": input_json["pred_masks"][idx]}
  83. bin_i = mask_utils.decode(rle_i)
  84. img_bgr_i = cv2.imread(img_path)
  85. if img_bgr_i is None:
  86. raise FileNotFoundError(f"Could not read image: {img_path}")
  87. img_rgb_i = cv2.cvtColor(img_bgr_i, cv2.COLOR_BGR2RGB)
  88. viz_i = Visualizer(
  89. img_rgb_i,
  90. font_size_multiplier=font_size_multiplier,
  91. boarder_width_multiplier=boarder_width_multiplier,
  92. )
  93. viz_i.overlay_instances(
  94. boxes=boxes_i,
  95. masks=[rle_i],
  96. binary_masks=[bin_i],
  97. assigned_colors=[color_hex],
  98. alpha=mask_alpha,
  99. label_mode=label_mode,
  100. )
  101. pil_mask_i = Image.fromarray(viz_i.output.get_image())
  102. return pil_mask_i, pil_mask_i_zoomed