visualization_utils.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
import json
import os
import subprocess
import tempfile
from pathlib import Path

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pycocotools.mask as mask_utils
import torch
from matplotlib.colors import to_rgb
from PIL import Image
from skimage.color import lab2rgb, rgb2lab
from sklearn.cluster import KMeans
from torchvision.ops import masks_to_boxes
from tqdm import tqdm
  20. def generate_colors(n_colors=256, n_samples=5000):
  21. # Step 1: Random RGB samples
  22. np.random.seed(42)
  23. rgb = np.random.rand(n_samples, 3)
  24. # Step 2: Convert to LAB for perceptual uniformity
  25. # print(f"Converting {n_samples} RGB samples to LAB color space...")
  26. lab = rgb2lab(rgb.reshape(1, -1, 3)).reshape(-1, 3)
  27. # print("Conversion to LAB complete.")
  28. # Step 3: k-means clustering in LAB
  29. kmeans = KMeans(n_clusters=n_colors, n_init=10)
  30. # print(f"Fitting KMeans with {n_colors} clusters on {n_samples} samples...")
  31. kmeans.fit(lab)
  32. # print("KMeans fitting complete.")
  33. centers_lab = kmeans.cluster_centers_
  34. # Step 4: Convert LAB back to RGB
  35. colors_rgb = lab2rgb(centers_lab.reshape(1, -1, 3)).reshape(-1, 3)
  36. colors_rgb = np.clip(colors_rgb, 0, 1)
  37. return colors_rgb
  38. COLORS = generate_colors(n_colors=128, n_samples=5000)
  39. def show_img_tensor(img_batch, vis_img_idx=0):
  40. MEAN_IMG = np.array([0.5, 0.5, 0.5])
  41. STD_IMG = np.array([0.5, 0.5, 0.5])
  42. im_tensor = img_batch[vis_img_idx].detach().cpu()
  43. assert im_tensor.dim() == 3
  44. im_tensor = im_tensor.numpy().transpose((1, 2, 0))
  45. im_tensor = (im_tensor * STD_IMG) + MEAN_IMG
  46. im_tensor = np.clip(im_tensor, 0, 1)
  47. plt.imshow(im_tensor)
  48. def draw_box_on_image(image, box, color=(0, 255, 0)):
  49. """
  50. Draws a rectangle on a given PIL image using the provided box coordinates in xywh format.
  51. :param image: PIL.Image - The image on which to draw the rectangle.
  52. :param box: tuple - A tuple (x, y, w, h) representing the top-left corner, width, and height of the rectangle.
  53. :param color: tuple - A tuple (R, G, B) representing the color of the rectangle. Default is red.
  54. :return: PIL.Image - The image with the rectangle drawn on it.
  55. """
  56. # Ensure the image is in RGB mode
  57. image = image.convert("RGB")
  58. # Unpack the box coordinates
  59. x, y, w, h = box
  60. x, y, w, h = int(x), int(y), int(w), int(h)
  61. # Get the pixel data
  62. pixels = image.load()
  63. # Draw the top and bottom edges
  64. for i in range(x, x + w):
  65. pixels[i, y] = color
  66. pixels[i, y + h - 1] = color
  67. pixels[i, y + 1] = color
  68. pixels[i, y + h] = color
  69. pixels[i, y - 1] = color
  70. pixels[i, y + h - 2] = color
  71. # Draw the left and right edges
  72. for j in range(y, y + h):
  73. pixels[x, j] = color
  74. pixels[x + 1, j] = color
  75. pixels[x - 1, j] = color
  76. pixels[x + w - 1, j] = color
  77. pixels[x + w, j] = color
  78. pixels[x + w - 2, j] = color
  79. return image
  80. def plot_bbox(
  81. img_height,
  82. img_width,
  83. box,
  84. box_format="XYXY",
  85. relative_coords=True,
  86. color="r",
  87. linestyle="solid",
  88. text=None,
  89. ax=None,
  90. ):
  91. if box_format == "XYXY":
  92. x, y, x2, y2 = box
  93. w = x2 - x
  94. h = y2 - y
  95. elif box_format == "XYWH":
  96. x, y, w, h = box
  97. elif box_format == "CxCyWH":
  98. cx, cy, w, h = box
  99. x = cx - w / 2
  100. y = cy - h / 2
  101. else:
  102. raise RuntimeError(f"Invalid box_format {box_format}")
  103. if relative_coords:
  104. x *= img_width
  105. w *= img_width
  106. y *= img_height
  107. h *= img_height
  108. if ax is None:
  109. ax = plt.gca()
  110. rect = patches.Rectangle(
  111. (x, y),
  112. w,
  113. h,
  114. linewidth=1.5,
  115. edgecolor=color,
  116. facecolor="none",
  117. linestyle=linestyle,
  118. )
  119. ax.add_patch(rect)
  120. if text is not None:
  121. facecolor = "w"
  122. ax.text(
  123. x,
  124. y - 5,
  125. text,
  126. color=color,
  127. weight="bold",
  128. fontsize=8,
  129. bbox={"facecolor": facecolor, "alpha": 0.75, "pad": 2},
  130. )
  131. def plot_mask(mask, color="r", ax=None):
  132. im_h, im_w = mask.shape
  133. mask_img = np.zeros((im_h, im_w, 4), dtype=np.float32)
  134. mask_img[..., :3] = to_rgb(color)
  135. mask_img[..., 3] = mask * 0.5
  136. # Use the provided ax or the current axis
  137. if ax is None:
  138. ax = plt.gca()
  139. ax.imshow(mask_img)
  140. def normalize_bbox(bbox_xywh, img_w, img_h):
  141. # Assumes bbox_xywh is in XYWH format
  142. if isinstance(bbox_xywh, list):
  143. assert len(bbox_xywh) == 4, (
  144. "bbox_xywh list must have 4 elements. Batching not support except for torch tensors."
  145. )
  146. normalized_bbox = bbox_xywh.copy()
  147. normalized_bbox[0] /= img_w
  148. normalized_bbox[1] /= img_h
  149. normalized_bbox[2] /= img_w
  150. normalized_bbox[3] /= img_h
  151. else:
  152. assert isinstance(bbox_xywh, torch.Tensor), (
  153. "Only torch tensors are supported for batching."
  154. )
  155. normalized_bbox = bbox_xywh.clone()
  156. assert normalized_bbox.size(-1) == 4, (
  157. "bbox_xywh tensor must have last dimension of size 4."
  158. )
  159. normalized_bbox[..., 0] /= img_w
  160. normalized_bbox[..., 1] /= img_h
  161. normalized_bbox[..., 2] /= img_w
  162. normalized_bbox[..., 3] /= img_h
  163. return normalized_bbox
  164. def visualize_frame_output(frame_idx, video_frames, outputs, figsize=(12, 8)):
  165. plt.figure(figsize=figsize)
  166. plt.title(f"frame {frame_idx}")
  167. img = load_frame(video_frames[frame_idx])
  168. img_H, img_W, _ = img.shape
  169. plt.imshow(img)
  170. for i in range(len(outputs["out_probs"])):
  171. box_xywh = outputs["out_boxes_xywh"][i]
  172. prob = outputs["out_probs"][i]
  173. obj_id = outputs["out_obj_ids"][i]
  174. binary_mask = outputs["out_binary_masks"][i]
  175. color = COLORS[obj_id % len(COLORS)]
  176. plot_bbox(
  177. img_H,
  178. img_W,
  179. box_xywh,
  180. text=f"(id={obj_id}, {prob=:.2f})",
  181. box_format="XYWH",
  182. color=color,
  183. )
  184. plot_mask(binary_mask, color=color)
def visualize_formatted_frame_output(
    frame_idx,
    video_frames,
    outputs_list,
    titles=None,
    points_list=None,
    points_labels_list=None,
    figsize=(12, 8),
    title_suffix="",
    prompt_info=None,
):
    """Visualize up to three sets of segmentation masks on a video frame.
    Args:
        frame_idx: Frame index to visualize
        image_files: List of image file paths
        outputs_list: List of {frame_idx: {obj_id: mask_tensor}} or single dict {obj_id: mask_tensor}
        titles: List of titles for each set of outputs_list
        points_list: Optional list of point coordinates
        points_labels_list: Optional list of point labels
        figsize: Figure size tuple
        save: Whether to save the visualization to file
        output_dir: Base output directory when saving
        scenario_name: Scenario name for organizing saved files
        title_suffix: Additional title suffix
        prompt_info: Dictionary with prompt information (boxes, points, etc.)
    """
    # Handle single output dict case
    if isinstance(outputs_list, dict) and frame_idx in outputs_list:
        # This is a single outputs dict with frame indices as keys
        outputs_list = [outputs_list]
    elif isinstance(outputs_list, dict) and not any(
        isinstance(k, int) for k in outputs_list.keys()
    ):
        # This is a single frame's outputs {obj_id: mask}
        single_frame_outputs = {frame_idx: outputs_list}
        outputs_list = [single_frame_outputs]
    num_outputs = len(outputs_list)
    if titles is None:
        titles = [f"Set {i + 1}" for i in range(num_outputs)]
    assert len(titles) == num_outputs, (
        "length of `titles` should match that of `outputs_list` if not None."
    )
    # One subplot per output set, laid out side by side.
    _, axes = plt.subplots(1, num_outputs, figsize=figsize)
    if num_outputs == 1:
        axes = [axes]  # Make it iterable
    img = load_frame(video_frames[frame_idx])
    img_H, img_W, _ = img.shape
    for idx in range(num_outputs):
        ax, outputs_set, ax_title = axes[idx], outputs_list[idx], titles[idx]
        ax.set_title(f"Frame {frame_idx} - {ax_title}{title_suffix}")
        ax.imshow(img)
        if frame_idx in outputs_set:
            _outputs = outputs_set[frame_idx]
        else:
            print(f"Warning: Frame {frame_idx} not found in outputs_set")
            continue
        if prompt_info and frame_idx == 0:  # Show prompts on first frame
            if "boxes" in prompt_info:
                for box in prompt_info["boxes"]:
                    # box is in [x, y, w, h] normalized format
                    x, y, w, h = box
                    plot_bbox(
                        img_H,
                        img_W,
                        [x, y, x + w, y + h],  # Convert to XYXY
                        box_format="XYXY",
                        relative_coords=True,
                        color="yellow",
                        linestyle="dashed",
                        text="PROMPT BOX",
                        ax=ax,
                    )
            if "points" in prompt_info and "point_labels" in prompt_info:
                points = np.array(prompt_info["points"])
                labels = np.array(prompt_info["point_labels"])
                # Convert normalized to pixel coordinates
                points_pixel = points * np.array([img_W, img_H])
                # Draw positive points (green stars)
                pos_points = points_pixel[labels == 1]
                if len(pos_points) > 0:
                    ax.scatter(
                        pos_points[:, 0],
                        pos_points[:, 1],
                        color="lime",
                        marker="*",
                        s=200,
                        edgecolor="white",
                        linewidth=2,
                        label="Positive Points",
                        zorder=10,
                    )
                # Draw negative points (red stars)
                neg_points = points_pixel[labels == 0]
                if len(neg_points) > 0:
                    ax.scatter(
                        neg_points[:, 0],
                        neg_points[:, 1],
                        color="red",
                        marker="*",
                        s=200,
                        edgecolor="white",
                        linewidth=2,
                        label="Negative Points",
                        zorder=10,
                    )
        objects_drawn = 0
        for obj_id, binary_mask in _outputs.items():
            # Works for both torch tensors and numpy arrays.
            mask_sum = (
                binary_mask.sum()
                if hasattr(binary_mask, "sum")
                else np.sum(binary_mask)
            )
            if mask_sum > 0:  # Only draw if mask has content
                # Convert to torch tensor if it's not already
                if not isinstance(binary_mask, torch.Tensor):
                    binary_mask = torch.tensor(binary_mask)
                # Find bounding box from mask
                if binary_mask.any():
                    box_xyxy = masks_to_boxes(binary_mask.unsqueeze(0)).squeeze()
                    box_xyxy = normalize_bbox(box_xyxy, img_W, img_H)
                else:
                    # Fallback: create a small box at center
                    box_xyxy = [0.45, 0.45, 0.55, 0.55]
                color = COLORS[obj_id % len(COLORS)]
                plot_bbox(
                    img_H,
                    img_W,
                    box_xyxy,
                    text=f"(id={obj_id})",
                    box_format="XYXY",
                    color=color,
                    ax=ax,
                )
                # Convert back to numpy for plotting
                mask_np = (
                    binary_mask.numpy()
                    if isinstance(binary_mask, torch.Tensor)
                    else binary_mask
                )
                plot_mask(mask_np, color=color, ax=ax)
                objects_drawn += 1
        if objects_drawn == 0:
            ax.text(
                0.5,
                0.5,
                "No objects detected",
                transform=ax.transAxes,
                fontsize=16,
                ha="center",
                va="center",
                color="red",
                weight="bold",
            )
        # Draw additional points if provided
        # NOTE(review): assumes points_labels_list is non-None whenever
        # points_list[idx] is non-None — confirm with callers.
        if points_list is not None and points_list[idx] is not None:
            show_points(
                points_list[idx], points_labels_list[idx], ax=ax, marker_size=200
            )
        ax.axis("off")
    plt.tight_layout()
    plt.show()
  346. def render_masklet_frame(img, outputs, frame_idx=None, alpha=0.5):
  347. """
  348. Overlays masklets and bounding boxes on a single image frame.
  349. Args:
  350. img: np.ndarray, shape (H, W, 3), uint8 or float32 in [0,255] or [0,1]
  351. outputs: dict with keys: out_boxes_xywh, out_probs, out_obj_ids, out_binary_masks
  352. frame_idx: int or None, for overlaying frame index text
  353. alpha: float, mask overlay alpha
  354. Returns:
  355. overlay: np.ndarray, shape (H, W, 3), uint8
  356. """
  357. if img.dtype == np.float32 or img.max() <= 1.0:
  358. img = (img * 255).astype(np.uint8)
  359. img = img[..., :3] # drop alpha if present
  360. height, width = img.shape[:2]
  361. overlay = img.copy()
  362. for i in range(len(outputs["out_probs"])):
  363. obj_id = outputs["out_obj_ids"][i]
  364. color = COLORS[obj_id % len(COLORS)]
  365. color255 = (color * 255).astype(np.uint8)
  366. mask = outputs["out_binary_masks"][i]
  367. if mask.shape != img.shape[:2]:
  368. mask = cv2.resize(
  369. mask.astype(np.float32),
  370. (img.shape[1], img.shape[0]),
  371. interpolation=cv2.INTER_NEAREST,
  372. )
  373. mask_bool = mask > 0.5
  374. for c in range(3):
  375. overlay[..., c][mask_bool] = (
  376. alpha * color255[c] + (1 - alpha) * overlay[..., c][mask_bool]
  377. ).astype(np.uint8)
  378. # Draw bounding boxes and text
  379. for i in range(len(outputs["out_probs"])):
  380. box_xywh = outputs["out_boxes_xywh"][i]
  381. obj_id = outputs["out_obj_ids"][i]
  382. prob = outputs["out_probs"][i]
  383. color = COLORS[obj_id % len(COLORS)]
  384. color255 = tuple(int(x * 255) for x in color)
  385. x, y, w, h = box_xywh
  386. x1 = int(x * width)
  387. y1 = int(y * height)
  388. x2 = int((x + w) * width)
  389. y2 = int((y + h) * height)
  390. cv2.rectangle(overlay, (x1, y1), (x2, y2), color255, 2)
  391. if prob is not None:
  392. label = f"id={obj_id}, p={prob:.2f}"
  393. else:
  394. label = f"id={obj_id}"
  395. cv2.putText(
  396. overlay,
  397. label,
  398. (x1, max(y1 - 10, 0)),
  399. cv2.FONT_HERSHEY_SIMPLEX,
  400. 0.5,
  401. color255,
  402. 1,
  403. cv2.LINE_AA,
  404. )
  405. # Overlay frame index at the top-left corner
  406. if frame_idx is not None:
  407. cv2.putText(
  408. overlay,
  409. f"Frame {frame_idx}",
  410. (10, 30),
  411. cv2.FONT_HERSHEY_SIMPLEX,
  412. 1.0,
  413. (255, 255, 255),
  414. 2,
  415. cv2.LINE_AA,
  416. )
  417. return overlay
  418. def save_masklet_video(video_frames, outputs, out_path, alpha=0.5, fps=10):
  419. # Each outputs dict has keys: "out_boxes_xywh", "out_probs", "out_obj_ids", "out_binary_masks"
  420. # video_frames: list of video frame data, same length as outputs_list
  421. # Read first frame to get size
  422. first_img = load_frame(video_frames[0])
  423. height, width = first_img.shape[:2]
  424. if first_img.dtype == np.float32 or first_img.max() <= 1.0:
  425. first_img = (first_img * 255).astype(np.uint8)
  426. # Use 'mp4v' for best compatibility with VSCode playback (.mp4 files)
  427. fourcc = cv2.VideoWriter_fourcc(*"mp4v")
  428. writer = cv2.VideoWriter("temp.mp4", fourcc, fps, (width, height))
  429. outputs_list = [
  430. (video_frames[frame_idx], frame_idx, outputs[frame_idx])
  431. for frame_idx in sorted(outputs.keys())
  432. ]
  433. for frame, frame_idx, frame_outputs in tqdm(outputs_list):
  434. img = load_frame(frame)
  435. overlay = render_masklet_frame(
  436. img, frame_outputs, frame_idx=frame_idx, alpha=alpha
  437. )
  438. writer.write(cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
  439. writer.release()
  440. # Re-encode the video for VSCode compatibility using ffmpeg
  441. subprocess.run(["ffmpeg", "-y", "-i", "temp.mp4", out_path])
  442. print(f"Re-encoded video saved to {out_path}")
  443. os.remove("temp.mp4") # Clean up temporary file
  444. def save_masklet_image(frame, outputs, out_path, alpha=0.5, frame_idx=None):
  445. """
  446. Save a single image with masklet overlays.
  447. """
  448. img = load_frame(frame)
  449. overlay = render_masklet_frame(img, outputs, frame_idx=frame_idx, alpha=alpha)
  450. Image.fromarray(overlay).save(out_path)
  451. print(f"Overlay image saved to {out_path}")
  452. def prepare_masks_for_visualization(frame_to_output):
  453. # frame_to_obj_masks --> {frame_idx: {'output_probs': np.array, `out_obj_ids`: np.array, `out_binary_masks`: np.array}}
  454. for frame_idx, out in frame_to_output.items():
  455. _processed_out = {}
  456. for idx, obj_id in enumerate(out["out_obj_ids"].tolist()):
  457. if out["out_binary_masks"][idx].any():
  458. _processed_out[obj_id] = out["out_binary_masks"][idx]
  459. frame_to_output[frame_idx] = _processed_out
  460. return frame_to_output
  461. def convert_coco_to_masklet_format(
  462. annotations, img_info, is_prediction=False, score_threshold=0.5
  463. ):
  464. """
  465. Convert COCO format annotations to format expected by render_masklet_frame
  466. """
  467. outputs = {
  468. "out_boxes_xywh": [],
  469. "out_probs": [],
  470. "out_obj_ids": [],
  471. "out_binary_masks": [],
  472. }
  473. img_h, img_w = img_info["height"], img_info["width"]
  474. for idx, ann in enumerate(annotations):
  475. # Get bounding box in relative XYWH format
  476. if "bbox" in ann:
  477. bbox = ann["bbox"]
  478. if max(bbox) > 1.0: # Convert absolute to relative coordinates
  479. bbox = [
  480. bbox[0] / img_w,
  481. bbox[1] / img_h,
  482. bbox[2] / img_w,
  483. bbox[3] / img_h,
  484. ]
  485. else:
  486. mask = mask_utils.decode(ann["segmentation"])
  487. rows = np.any(mask, axis=1)
  488. cols = np.any(mask, axis=0)
  489. if np.any(rows) and np.any(cols):
  490. rmin, rmax = np.where(rows)[0][[0, -1]]
  491. cmin, cmax = np.where(cols)[0][[0, -1]]
  492. # Convert to relative XYWH
  493. bbox = [
  494. cmin / img_w,
  495. rmin / img_h,
  496. (cmax - cmin + 1) / img_w,
  497. (rmax - rmin + 1) / img_h,
  498. ]
  499. else:
  500. bbox = [0, 0, 0, 0]
  501. outputs["out_boxes_xywh"].append(bbox)
  502. # Get probability/score
  503. if is_prediction:
  504. prob = ann["score"]
  505. else:
  506. prob = 1.0 # GT has no probability
  507. outputs["out_probs"].append(prob)
  508. outputs["out_obj_ids"].append(idx)
  509. mask = mask_utils.decode(ann["segmentation"])
  510. mask = (mask > score_threshold).astype(np.uint8)
  511. outputs["out_binary_masks"].append(mask)
  512. return outputs
  513. def save_side_by_side_visualization(img, gt_anns, pred_anns, noun_phrase):
  514. """
  515. Create side-by-side visualization of GT and predictions
  516. """
  517. # Create side-by-side visualization
  518. fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7))
  519. main_title = f"Noun phrase: '{noun_phrase}'"
  520. fig.suptitle(main_title, fontsize=16, fontweight="bold")
  521. gt_overlay = render_masklet_frame(img, gt_anns, alpha=0.5)
  522. ax1.imshow(gt_overlay)
  523. ax1.set_title("Ground Truth", fontsize=14, fontweight="bold")
  524. ax1.axis("off")
  525. pred_overlay = render_masklet_frame(img, pred_anns, alpha=0.5)
  526. ax2.imshow(pred_overlay)
  527. ax2.set_title("Predictions", fontsize=14, fontweight="bold")
  528. ax2.axis("off")
  529. plt.subplots_adjust(top=0.88)
  530. plt.tight_layout()
  531. def bitget(val, idx):
  532. return (val >> idx) & 1
  533. def pascal_color_map():
  534. colormap = np.zeros((512, 3), dtype=int)
  535. ind = np.arange(512, dtype=int)
  536. for shift in reversed(list(range(8))):
  537. for channel in range(3):
  538. colormap[:, channel] |= bitget(ind, channel) << shift
  539. ind >>= 3
  540. return colormap.astype(np.uint8)
def draw_masks_to_frame(
    frame: np.ndarray, masks: np.ndarray, colors: np.ndarray
) -> np.ndarray:
    """Blend each mask onto the frame and trace its contours.

    Each mask is alpha-blended (25% color / 75% frame) and outlined with a
    white/black/color triple contour so it stays visible on any background.
    Returns the composited frame; the input frame itself is only mutated by
    the contour drawing, not the blending (addWeighted returns a new array).
    """
    masked_frame = frame
    for mask, color in zip(masks, colors):
        # Solid-color copy of the frame where the mask is on; blended below.
        curr_masked_frame = np.where(mask[..., None], color, masked_frame)
        masked_frame = cv2.addWeighted(masked_frame, 0.75, curr_masked_frame, 0.25, 0)
        # findContours returned (image, contours, hierarchy) before OpenCV 4.
        if int(cv2.__version__[0]) > 3:
            contours, _ = cv2.findContours(
                np.array(mask, dtype=np.uint8).copy(),
                cv2.RETR_TREE,
                cv2.CHAIN_APPROX_NONE,
            )
        else:
            _, contours, _ = cv2.findContours(
                np.array(mask, dtype=np.uint8).copy(),
                cv2.RETR_TREE,
                cv2.CHAIN_APPROX_NONE,
            )
        # Widest-to-narrowest so each ring remains visible.
        cv2.drawContours(
            masked_frame, contours, -1, (255, 255, 255), 7
        )  # White outer contour
        cv2.drawContours(
            masked_frame, contours, -1, (0, 0, 0), 5
        )  # Black middle contour
        cv2.drawContours(
            masked_frame, contours, -1, color.tolist(), 3
        )  # Original color inner contour
    return masked_frame
  570. def get_annot_df(file_path: str):
  571. with open(file_path, "r") as f:
  572. data = json.load(f)
  573. dfs = {}
  574. for k, v in data.items():
  575. if k in ("info", "licenses"):
  576. dfs[k] = v
  577. continue
  578. df = pd.DataFrame(v)
  579. dfs[k] = df
  580. return dfs
  581. def get_annot_dfs(file_list: list[str]):
  582. dfs = {}
  583. for annot_file in tqdm(file_list):
  584. dataset_name = Path(annot_file).stem
  585. dfs[dataset_name] = get_annot_df(annot_file)
  586. return dfs
  587. def get_media_dir(media_dir: str, dataset: str):
  588. if dataset in ["saco_veval_sav_test", "saco_veval_sav_val"]:
  589. return os.path.join(media_dir, "saco_sav", "JPEGImages_24fps")
  590. elif dataset in ["saco_veval_yt1b_test", "saco_veval_yt1b_val"]:
  591. return os.path.join(media_dir, "saco_yt1b", "JPEGImages_6fps")
  592. elif dataset in ["saco_veval_smartglasses_test", "saco_veval_smartglasses_val"]:
  593. return os.path.join(media_dir, "saco_sg", "JPEGImages_6fps")
  594. elif dataset == "sa_fari_test":
  595. return os.path.join(media_dir, "sa_fari", "JPEGImages_6fps")
  596. else:
  597. raise ValueError(f"Dataset {dataset} not found")
def get_all_annotations_for_frame(
    dataset_df: pd.DataFrame, video_id: int, frame_idx: int, data_dir: str, dataset: str
):
    """Load one video frame plus every mask / noun phrase annotated on it.

    Returns (frame_rgb, masks, noun_phrases); masks and noun_phrases are None
    when the video has no annotations. Mask/phrase pairs are sorted by phrase.
    """
    media_dir = os.path.join(data_dir, "media")
    # Load the annotation and video data
    annot_df = dataset_df["annotations"]
    video_df = dataset_df["videos"]
    # Get the frame
    video_df_current = video_df[video_df.id == video_id]
    assert len(video_df_current) == 1, (
        f"Expected 1 video row, got {len(video_df_current)}"
    )
    video_row = video_df_current.iloc[0]
    # NOTE(review): assumes video_row.file_names is an in-order list of frame
    # paths relative to the dataset's media dir — confirm with the annot JSON.
    file_name = video_row.file_names[frame_idx]
    file_path = os.path.join(
        get_media_dir(media_dir=media_dir, dataset=dataset), file_name
    )
    frame = cv2.imread(file_path)
    # OpenCV loads BGR; convert to RGB for downstream matplotlib/PIL use.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Get the masks and noun phrases annotated in this video in this frame
    annot_df_current_video = annot_df[annot_df.video_id == video_id]
    if len(annot_df_current_video) == 0:
        print(f"No annotations found for video_id {video_id}")
        return frame, None, None
    else:
        # Objects without a segmentation in this frame get an all-zero mask.
        empty_mask = np.zeros(frame.shape[:2], dtype=np.uint8)
        mask_np_pairs = annot_df_current_video.apply(
            lambda row: (
                (
                    mask_utils.decode(row.segmentations[frame_idx])
                    if row.segmentations[frame_idx]
                    else empty_mask
                ),
                row.noun_phrase,
            ),
            axis=1,
        )
        # sort based on noun_phrases
        mask_np_pairs = sorted(mask_np_pairs, key=lambda x: x[1])
        masks, noun_phrases = zip(*mask_np_pairs)
        return frame, masks, noun_phrases
def visualize_prompt_overlay(
    frame_idx,
    video_frames,
    title="Prompt Visualization",
    text_prompt=None,
    point_prompts=None,
    point_labels=None,
    bounding_boxes=None,
    box_labels=None,
    obj_id=None,
):
    """Simple prompt visualization function.

    Point and box coordinates are interpreted as normalized [0, 1] image
    coordinates and scaled by the frame size before drawing. Labels of 1 are
    rendered green (positive), anything else red (negative); missing labels
    default to positive.
    """
    img = Image.fromarray(load_frame(video_frames[frame_idx]))
    fig, ax = plt.subplots(1, figsize=(6, 4))
    ax.imshow(img)
    img_w, img_h = img.size
    if text_prompt:
        # Text prompt banner pinned to the top-left corner (axes coords).
        ax.text(
            0.02,
            0.98,
            f'Text: "{text_prompt}"',
            transform=ax.transAxes,
            fontsize=12,
            color="white",
            weight="bold",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="red", alpha=0.7),
            verticalalignment="top",
        )
    if point_prompts:
        for i, point in enumerate(point_prompts):
            x, y = point
            # Convert relative to absolute coordinates
            x_img, y_img = x * img_w, y * img_h
            # Use different colors for positive/negative points
            if point_labels and len(point_labels) > i:
                color = "green" if point_labels[i] == 1 else "red"
                marker = "o" if point_labels[i] == 1 else "x"
            else:
                # No label supplied for this point: treat it as positive.
                color = "green"
                marker = "o"
            ax.plot(
                x_img,
                y_img,
                marker=marker,
                color=color,
                markersize=10,
                markeredgewidth=2,
                markeredgecolor="white",
            )
            # Small "P<n>" tag next to each point.
            ax.text(
                x_img + 5,
                y_img - 5,
                f"P{i + 1}",
                color=color,
                fontsize=10,
                weight="bold",
                bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.8),
            )
    if bounding_boxes:
        for i, box in enumerate(bounding_boxes):
            x, y, w, h = box
            # Convert relative to absolute coordinates
            x_img, y_img = x * img_w, y * img_h
            w_img, h_img = w * img_w, h * img_h
            # Use different colors for positive/negative boxes
            if box_labels and len(box_labels) > i:
                color = "green" if box_labels[i] == 1 else "red"
            else:
                color = "green"
            rect = patches.Rectangle(
                (x_img, y_img),
                w_img,
                h_img,
                linewidth=2,
                edgecolor=color,
                facecolor="none",
            )
            ax.add_patch(rect)
            # Small "B<n>" tag at the box's top-left corner.
            ax.text(
                x_img,
                y_img - 5,
                f"B{i + 1}",
                color=color,
                fontsize=10,
                weight="bold",
                bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.8),
            )
    # Add object ID info if provided
    if obj_id is not None:
        ax.text(
            0.02,
            0.02,
            f"Object ID: {obj_id}",
            transform=ax.transAxes,
            fontsize=10,
            color="white",
            weight="bold",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="blue", alpha=0.7),
            verticalalignment="bottom",
        )
    ax.set_title(title)
    ax.axis("off")
    plt.tight_layout()
    plt.show()
  743. def plot_results(img, results):
  744. plt.figure(figsize=(12, 8))
  745. plt.imshow(img)
  746. nb_objects = len(results["scores"])
  747. print(f"found {nb_objects} object(s)")
  748. for i in range(nb_objects):
  749. color = COLORS[i % len(COLORS)]
  750. plot_mask(results["masks"][i].squeeze(0).cpu(), color=color)
  751. w, h = img.size
  752. prob = results["scores"][i].item()
  753. plot_bbox(
  754. h,
  755. w,
  756. results["boxes"][i].cpu(),
  757. text=f"(id={i}, {prob=:.2f})",
  758. box_format="XYXY",
  759. color=color,
  760. relative_coords=False,
  761. )
  762. def single_visualization(img, anns, title):
  763. """
  764. Create a single image visualization with overlays.
  765. """
  766. fig, ax = plt.subplots(figsize=(7, 7))
  767. fig.suptitle(title, fontsize=16, fontweight="bold")
  768. overlay = render_masklet_frame(img, anns, alpha=0.5)
  769. ax.imshow(overlay)
  770. ax.axis("off")
  771. plt.tight_layout()
  772. def show_mask(mask, ax, obj_id=None, random_color=False):
  773. if random_color:
  774. color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
  775. else:
  776. cmap = plt.get_cmap("tab10")
  777. cmap_idx = 0 if obj_id is None else obj_id
  778. color = np.array([*cmap(cmap_idx)[:3], 0.6])
  779. h, w = mask.shape[-2:]
  780. mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
  781. ax.imshow(mask_image)
  782. def show_box(box, ax):
  783. x0, y0 = box[0], box[1]
  784. w, h = box[2] - box[0], box[3] - box[1]
  785. ax.add_patch(
  786. plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
  787. )
  788. def show_points(coords, labels, ax, marker_size=375):
  789. pos_points = coords[labels == 1]
  790. neg_points = coords[labels == 0]
  791. ax.scatter(
  792. pos_points[:, 0],
  793. pos_points[:, 1],
  794. color="green",
  795. marker="*",
  796. s=marker_size,
  797. edgecolor="white",
  798. linewidth=1.25,
  799. )
  800. ax.scatter(
  801. neg_points[:, 0],
  802. neg_points[:, 1],
  803. color="red",
  804. marker="*",
  805. s=marker_size,
  806. edgecolor="white",
  807. linewidth=1.25,
  808. )
  809. def load_frame(frame):
  810. if isinstance(frame, np.ndarray):
  811. img = frame
  812. elif isinstance(frame, Image.Image):
  813. img = np.array(frame)
  814. elif isinstance(frame, str) and os.path.isfile(frame):
  815. img = plt.imread(frame)
  816. else:
  817. raise ValueError(f"Invalid video frame type: {type(frame)=}")
  818. return img