| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- from typing import Dict, List
- import numpy as np
- import torch
- try:
- from pycocotools import mask as mask_utils
- except Exception:
- mask_utils = None
- def mask_intersection(
- masks1: torch.Tensor, masks2: torch.Tensor, block_size: int = 16
- ) -> torch.Tensor:
- assert masks1.shape[1:] == masks2.shape[1:]
- assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
- N, M = masks1.shape[0], masks2.shape[0]
- out = torch.zeros(N, M, device=masks1.device, dtype=torch.long)
- for i in range(0, N, block_size):
- for j in range(0, M, block_size):
- a = masks1[i : i + block_size]
- b = masks2[j : j + block_size]
- inter = (a[:, None] & b[None, :]).flatten(-2).sum(-1)
- out[i : i + block_size, j : j + block_size] = inter
- return out
- def mask_iom(masks1: torch.Tensor, masks2: torch.Tensor) -> torch.Tensor:
- assert masks1.shape[1:] == masks2.shape[1:]
- assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
- inter = mask_intersection(masks1, masks2)
- area1 = masks1.flatten(-2).sum(-1) # (N,)
- area2 = masks2.flatten(-2).sum(-1) # (M,)
- min_area = torch.min(area1[:, None], area2[None, :]).clamp_min(1)
- return inter.float() / (min_area.float() + 1e-8)
- def _decode_single_mask(mask_repr, h: int, w: int) -> np.ndarray:
- if isinstance(mask_repr, (list, tuple, np.ndarray)):
- arr = np.array(mask_repr)
- if arr.ndim != 2:
- raise ValueError("Mask array must be 2D (H, W).")
- return (arr > 0).astype(np.uint8)
- if mask_utils is None:
- raise ImportError(
- "pycocotools is required to decode RLE mask strings. pip install pycocotools"
- )
- if not isinstance(mask_repr, (str, bytes)):
- raise ValueError("Unsupported mask representation type for RLE decode.")
- rle = {
- "counts": mask_repr if isinstance(mask_repr, (str, bytes)) else str(mask_repr),
- "size": [h, w],
- }
- decoded = mask_utils.decode(rle)
- if decoded.ndim == 3:
- decoded = decoded[:, :, 0]
- return (decoded > 0).astype(np.uint8)
- def _decode_masks_to_torch_bool(pred_masks: List, h: int, w: int) -> torch.Tensor:
- bin_masks = [_decode_single_mask(m, h, w) for m in pred_masks]
- masks_np = np.stack(bin_masks, axis=0).astype(np.uint8) # (N, H, W)
- return torch.from_numpy(masks_np > 0)
- def remove_overlapping_masks(sample: Dict, iom_thresh: float = 0.3) -> Dict:
- """
- Greedy keep: sort by score desc; keep a mask if IoM to all kept masks <= threshold.
- If pred_masks has length 0 or 1, returns sample unchanged (no extra keys).
- """
- # Basic presence checks
- if "pred_masks" not in sample or not isinstance(sample["pred_masks"], list):
- return sample # nothing to do / preserve as-is
- pred_masks = sample["pred_masks"]
- N = len(pred_masks)
- # --- Early exit: 0 or 1 mask -> do NOT modify the JSON at all ---
- if N <= 1:
- return sample
- # From here on we have at least 2 masks
- h = int(sample["orig_img_h"])
- w = int(sample["orig_img_w"])
- pred_scores = sample.get("pred_scores", [1.0] * N) # fallback if scores missing
- pred_boxes = sample.get("pred_boxes", None)
- assert N == len(pred_scores), "pred_masks and pred_scores must have same length"
- if pred_boxes is not None:
- assert N == len(pred_boxes), "pred_masks and pred_boxes must have same length"
- masks_bool = _decode_masks_to_torch_bool(pred_masks, h, w) # (N, H, W)
- order = sorted(range(N), key=lambda i: float(pred_scores[i]), reverse=True)
- kept_idx: List[int] = []
- kept_masks: List[torch.Tensor] = []
- for i in order:
- cand = masks_bool[i].unsqueeze(0) # (1, H, W)
- if len(kept_masks) == 0:
- kept_idx.append(i)
- kept_masks.append(masks_bool[i])
- continue
- kept_stack = torch.stack(kept_masks, dim=0) # (K, H, W)
- iom_vals = mask_iom(cand, kept_stack).squeeze(0) # (K,)
- if torch.any(iom_vals > iom_thresh):
- continue # overlaps too much with a higher-scored kept mask
- kept_idx.append(i)
- kept_masks.append(masks_bool[i])
- kept_idx_sorted = sorted(kept_idx)
- # Build filtered JSON (this *does* modify fields; only for N>=2 case)
- out = dict(sample)
- out["pred_masks"] = [pred_masks[i] for i in kept_idx_sorted]
- out["pred_scores"] = [pred_scores[i] for i in kept_idx_sorted]
- if pred_boxes is not None:
- out["pred_boxes"] = [pred_boxes[i] for i in kept_idx_sorted]
- out["kept_indices"] = kept_idx_sorted
- out["removed_indices"] = [i for i in range(N) if i not in set(kept_idx_sorted)]
- out["iom_threshold"] = float(iom_thresh)
- return out
|