mask_overlap_removal.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. from typing import Dict, List
  4. import numpy as np
  5. import torch
  6. try:
  7. from pycocotools import mask as mask_utils
  8. except Exception:
  9. mask_utils = None
  10. def mask_intersection(
  11. masks1: torch.Tensor, masks2: torch.Tensor, block_size: int = 16
  12. ) -> torch.Tensor:
  13. assert masks1.shape[1:] == masks2.shape[1:]
  14. assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
  15. N, M = masks1.shape[0], masks2.shape[0]
  16. out = torch.zeros(N, M, device=masks1.device, dtype=torch.long)
  17. for i in range(0, N, block_size):
  18. for j in range(0, M, block_size):
  19. a = masks1[i : i + block_size]
  20. b = masks2[j : j + block_size]
  21. inter = (a[:, None] & b[None, :]).flatten(-2).sum(-1)
  22. out[i : i + block_size, j : j + block_size] = inter
  23. return out
  24. def mask_iom(masks1: torch.Tensor, masks2: torch.Tensor) -> torch.Tensor:
  25. assert masks1.shape[1:] == masks2.shape[1:]
  26. assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
  27. inter = mask_intersection(masks1, masks2)
  28. area1 = masks1.flatten(-2).sum(-1) # (N,)
  29. area2 = masks2.flatten(-2).sum(-1) # (M,)
  30. min_area = torch.min(area1[:, None], area2[None, :]).clamp_min(1)
  31. return inter.float() / (min_area.float() + 1e-8)
  32. def _decode_single_mask(mask_repr, h: int, w: int) -> np.ndarray:
  33. if isinstance(mask_repr, (list, tuple, np.ndarray)):
  34. arr = np.array(mask_repr)
  35. if arr.ndim != 2:
  36. raise ValueError("Mask array must be 2D (H, W).")
  37. return (arr > 0).astype(np.uint8)
  38. if mask_utils is None:
  39. raise ImportError(
  40. "pycocotools is required to decode RLE mask strings. pip install pycocotools"
  41. )
  42. if not isinstance(mask_repr, (str, bytes)):
  43. raise ValueError("Unsupported mask representation type for RLE decode.")
  44. rle = {
  45. "counts": mask_repr if isinstance(mask_repr, (str, bytes)) else str(mask_repr),
  46. "size": [h, w],
  47. }
  48. decoded = mask_utils.decode(rle)
  49. if decoded.ndim == 3:
  50. decoded = decoded[:, :, 0]
  51. return (decoded > 0).astype(np.uint8)
  52. def _decode_masks_to_torch_bool(pred_masks: List, h: int, w: int) -> torch.Tensor:
  53. bin_masks = [_decode_single_mask(m, h, w) for m in pred_masks]
  54. masks_np = np.stack(bin_masks, axis=0).astype(np.uint8) # (N, H, W)
  55. return torch.from_numpy(masks_np > 0)
  56. def remove_overlapping_masks(sample: Dict, iom_thresh: float = 0.3) -> Dict:
  57. """
  58. Greedy keep: sort by score desc; keep a mask if IoM to all kept masks <= threshold.
  59. If pred_masks has length 0 or 1, returns sample unchanged (no extra keys).
  60. """
  61. # Basic presence checks
  62. if "pred_masks" not in sample or not isinstance(sample["pred_masks"], list):
  63. return sample # nothing to do / preserve as-is
  64. pred_masks = sample["pred_masks"]
  65. N = len(pred_masks)
  66. # --- Early exit: 0 or 1 mask -> do NOT modify the JSON at all ---
  67. if N <= 1:
  68. return sample
  69. # From here on we have at least 2 masks
  70. h = int(sample["orig_img_h"])
  71. w = int(sample["orig_img_w"])
  72. pred_scores = sample.get("pred_scores", [1.0] * N) # fallback if scores missing
  73. pred_boxes = sample.get("pred_boxes", None)
  74. assert N == len(pred_scores), "pred_masks and pred_scores must have same length"
  75. if pred_boxes is not None:
  76. assert N == len(pred_boxes), "pred_masks and pred_boxes must have same length"
  77. masks_bool = _decode_masks_to_torch_bool(pred_masks, h, w) # (N, H, W)
  78. order = sorted(range(N), key=lambda i: float(pred_scores[i]), reverse=True)
  79. kept_idx: List[int] = []
  80. kept_masks: List[torch.Tensor] = []
  81. for i in order:
  82. cand = masks_bool[i].unsqueeze(0) # (1, H, W)
  83. if len(kept_masks) == 0:
  84. kept_idx.append(i)
  85. kept_masks.append(masks_bool[i])
  86. continue
  87. kept_stack = torch.stack(kept_masks, dim=0) # (K, H, W)
  88. iom_vals = mask_iom(cand, kept_stack).squeeze(0) # (K,)
  89. if torch.any(iom_vals > iom_thresh):
  90. continue # overlaps too much with a higher-scored kept mask
  91. kept_idx.append(i)
  92. kept_masks.append(masks_bool[i])
  93. kept_idx_sorted = sorted(kept_idx)
  94. # Build filtered JSON (this *does* modify fields; only for N>=2 case)
  95. out = dict(sample)
  96. out["pred_masks"] = [pred_masks[i] for i in kept_idx_sorted]
  97. out["pred_scores"] = [pred_scores[i] for i in kept_idx_sorted]
  98. if pred_boxes is not None:
  99. out["pred_boxes"] = [pred_boxes[i] for i in kept_idx_sorted]
  100. out["kept_indices"] = kept_idx_sorted
  101. out["removed_indices"] = [i for i in range(N) if i not in set(kept_idx_sorted)]
  102. out["iom_threshold"] = float(iom_thresh)
  103. return out