| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428 |
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- import numpy as np
- import torch
- import torch.nn.functional as F
- from numpy.typing import NDArray
- from sam3.model.edt import edt_triton
def sample_box_points(
    masks: torch.Tensor,
    noise: float = 0.1,  # SAM default
    noise_bound: int = 20,  # SAM default
    top_left_label: int = 2,
    bottom_right_label: int = 3,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a box prompt for each mask as two corner points, optionally jittered by noise.

    The tight box of each mask (from `mask_to_box`) is perturbed per coordinate by uniform
    noise in [-max_d, max_d], where max_d = min(noise * box side, noise_bound), and then
    clamped to the image extent.

    Inputs:
    - masks: [B, 1, H, W] tensor, dtype=torch.bool
    - noise: noise as a fraction of box width and height, dtype=float; <= 0 disables noise
    - noise_bound: maximum amount of noise (in pure pixels), dtype=int
    - top_left_label: label assigned to the top-left corner point
    - bottom_right_label: label assigned to the bottom-right corner point
    Returns:
    - box_coords: [B, 2, 2], (x, y) coordinates of the top-left and bottom-right box
      corners, dtype=torch.float (whether or not noise is applied)
    - box_labels: [B, 2], [top_left_label, bottom_right_label] for every box (2 and 3 by
      default), dtype=torch.int32
    """
    device = masks.device
    # mask_to_box yields integer pixel coordinates; cast once so the output dtype is
    # float regardless of whether noise is applied below.
    box_coords = mask_to_box(masks).float()
    B, _, H, W = masks.shape
    box_labels = torch.tensor(
        [top_left_label, bottom_right_label], dtype=torch.int, device=device
    ).repeat(B)
    if noise > 0.0:
        if not isinstance(noise_bound, torch.Tensor):
            noise_bound = torch.tensor(noise_bound, device=device)
        bbox_w = box_coords[..., 2] - box_coords[..., 0]
        bbox_h = box_coords[..., 3] - box_coords[..., 1]
        # per-box noise amplitude, capped at `noise_bound` pixels
        max_dx = torch.min(bbox_w * noise, noise_bound)
        max_dy = torch.min(bbox_h * noise, noise_bound)
        # uniform noise in [-1, 1) for (x0, y0, x1, y1), scaled by the amplitudes
        box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
        box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
        box_coords = box_coords + box_noise
        img_bounds = (
            torch.tensor([W, H, W, H], device=device) - 1
        )  # uncentered pixel coords
        box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds)  # In place clamping
    box_coords = box_coords.reshape(-1, 2, 2)  # always 2 points
    box_labels = box_labels.reshape(-1, 2)
    return box_coords, box_labels
def mask_to_box(masks: torch.Tensor):
    """
    Compute the tight bounding box of every mask.

    Inputs:
    - masks: [B, 1, H, W] boolean tensor
    Returns:
    - bbox_coords: [B, 1, 4] int32 tensor of (x_min, y_min, x_max, y_max) in inclusive
      pixel coordinates; masks without any foreground pixel get an all-zero box
    """
    _, _, height, width = masks.shape
    # Column indices broadcast along rows, row indices broadcast along columns.
    col_ids = torch.arange(width, device=masks.device, dtype=torch.int32)
    row_ids = torch.arange(height, device=masks.device, dtype=torch.int32)[:, None]

    def _reduce_over_fg(coords, background_fill, reducer):
        # Background pixels take a value that can never win the reduction.
        filled = torch.where(masks, coords, background_fill)
        return reducer(filled.flatten(-2), dim=-1)

    x_min = _reduce_over_fg(col_ids, width, torch.amin)
    x_max = _reduce_over_fg(col_ids, -1, torch.amax)
    y_min = _reduce_over_fg(row_ids, height, torch.amin)
    y_max = _reduce_over_fg(row_ids, -1, torch.amax)
    boxes = torch.stack((x_min, y_min, x_max, y_max), dim=-1)

    has_foreground = masks.flatten(-2).any(dim=-1)[..., None]
    return torch.where(has_foreground, boxes, torch.zeros_like(boxes))
def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
    """
    Draw `num_pt` clicks per mask, each sampled independently and uniformly from the
    error regions.

    A click in a false-positive region (pred=1, gt=0) is negative (label 0); a click in a
    false-negative region (pred=0, gt=1) is positive (label 1). When a prediction matches
    its ground truth exactly, a negative click is drawn from the background instead.

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool, or None (treated as empty)
    - num_pt: int, number of points to sample independently for each of the B error maps
    Outputs:
    - points: [B, num_pt, 2], dtype=torch.float, (x, y) coordinates of each sampled point
    - labels: [B, num_pt], dtype=torch.int32, 1 for positive and 0 for negative clicks
    """
    if pred_masks is None:  # a missing prediction counts as an empty mask
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
    assert num_pt >= 0

    batch_size, _, height, width = gt_masks.shape
    # [B, 1, 1, 1] flag: prediction identical to ground truth at every pixel
    perfect = (gt_masks == pred_masks).flatten(2).all(dim=2)[..., None, None]
    # pixels eligible for a negative click: FP pixels, or background if nothing is wrong
    neg_region = (~gt_masks & pred_masks) | (perfect & ~gt_masks)
    # pixels eligible for a positive click: FN pixels
    pos_region = gt_masks & ~pred_masks

    # One random score per (point, pixel, channel); channel 0 = negative, 1 = positive.
    # Zeroing ineligible entries and taking the argmax picks a uniform eligible entry.
    scores = torch.rand(batch_size, num_pt, height, width, 2, device=gt_masks.device)
    eligible = torch.stack((neg_region, pos_region), dim=-1)  # [B, 1, H, W, 2]
    flat_idx = (scores * eligible).flatten(2).argmax(dim=2)

    labels = (flat_idx % 2).to(torch.int32)
    pixel_idx = flat_idx // 2
    xs = pixel_idx % width
    ys = pixel_idx // width
    points = torch.stack((xs, ys), dim=2).to(torch.float)
    return points, labels
def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
    """
    Sample 1 point (along with its label) from the center of each error region,
    that is, the point with the largest distance to the boundary of each error region.
    This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None (treated as empty)
    - padding: if True, pad with boundary of 1 px for distance transform
    Outputs:
    - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
    - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
    """
    if pred_masks is None:
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
    B, _, H, W = gt_masks.shape
    # false positive region, a new point sampled in this region should have
    # negative label to correct the FP error
    fp_masks = (~gt_masks & pred_masks).squeeze(1)
    # false negative region, a new point sampled in this region should have
    # positive label to correct the FN error
    fn_masks = (gt_masks & ~pred_masks).squeeze(1)
    if padding:
        # Surround each [H, W] map with a 1-px False border (same as the np.pad path of
        # the slow reference). Presumably this makes the image edge count as a region
        # boundary for edt_triton -- TODO confirm edt_triton's border semantics.
        padded_fp_masks = torch.zeros(
            B, H + 2, W + 2, dtype=fp_masks.dtype, device=fp_masks.device
        )
        padded_fp_masks[:, 1 : H + 1, 1 : W + 1] = fp_masks
        padded_fn_masks = torch.zeros(
            B, H + 2, W + 2, dtype=fn_masks.dtype, device=fn_masks.device
        )
        padded_fn_masks[:, 1 : H + 1, 1 : W + 1] = fn_masks
    else:
        padded_fp_masks = fp_masks
        padded_fn_masks = fn_masks
    fn_mask_dt = edt_triton(padded_fn_masks)
    fp_mask_dt = edt_triton(padded_fp_masks)
    if padding:
        # drop the border again so flat indices refer to the original H x W grid
        fn_mask_dt = fn_mask_dt[:, 1:-1, 1:-1]
        fp_mask_dt = fp_mask_dt[:, 1:-1, 1:-1]
    # deepest pixel of each FN / FP map, per batch element
    fn_max, fn_argmax = fn_mask_dt.reshape(B, -1).max(dim=-1)
    fp_max, fp_argmax = fp_mask_dt.reshape(B, -1).max(dim=-1)
    # click into whichever error region is "thicker"; ties go to the FP (negative) side
    is_positive = fn_max > fp_max
    chosen = torch.where(is_positive, fn_argmax, fp_argmax)
    points_x = chosen % W
    points_y = chosen // W
    # Match the documented dtypes (and the slow / uniform samplers, which get_next_point
    # uses interchangeably): float coordinates and int32 labels.
    points = torch.stack([points_x, points_y], dim=-1).to(torch.float)
    labels = is_positive.to(torch.int32)
    return points.unsqueeze(1), labels.unsqueeze(1)
def sample_one_point_from_error_center_slow(gt_masks, pred_masks, padding=True):
    """
    CPU / OpenCV reference version of `sample_one_point_from_error_center`.

    For each mask, pick the pixel farthest from the boundary of its FN region and of its
    FP region (via cv2.distanceTransform), then click the one with the larger distance:
    positive (1) if it lies in the FN region, negative (0) otherwise (ties -> negative).
    This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None (treated as empty)
    - padding: if True, pad with boundary of 1 px for distance transform
    Outputs:
    - points: [B, 1, 2], dtype=torch.float, (x, y) coordinates of each sampled point
    - labels: [B, 1], dtype=torch.int32, 1 for positive clicks and 0 for negative clicks
    """
    import cv2  # delay OpenCV import to avoid unnecessary dependency

    if pred_masks is None:
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
    batch_size, _, _, width = gt_masks.shape
    out_device = gt_masks.device

    # FP pixels call for a negative click, FN pixels for a positive one
    fp_np = (~gt_masks & pred_masks).cpu().numpy()
    fn_np = (gt_masks & ~pred_masks).cpu().numpy()

    def _flat_distance_map(region):
        # Distance transform of one [H, W] region, flattened; optionally computed on a
        # copy with a 1-px zero border that is cropped off again afterwards.
        if padding:
            region = np.pad(region, ((1, 1), (1, 1)), "constant")
        dist = cv2.distanceTransform(region.astype(np.uint8), cv2.DIST_L2, 0)
        if padding:
            dist = dist[1:-1, 1:-1]
        return dist.reshape(-1)

    points = torch.zeros(batch_size, 1, 2, dtype=torch.float)
    labels = torch.ones(batch_size, 1, dtype=torch.int32)
    for b in range(batch_size):
        fn_dist = _flat_distance_map(fn_np[b, 0])
        fp_dist = _flat_distance_map(fp_np[b, 0])
        fn_best = np.argmax(fn_dist)
        fp_best = np.argmax(fp_dist)
        positive = fn_dist[fn_best] > fp_dist[fp_best]
        flat_pos = fn_best if positive else fp_best
        points[b, 0, 0] = flat_pos % width  # x
        points[b, 0, 1] = flat_pos // width  # y
        labels[b, 0] = int(positive)

    return points.to(out_device), labels.to(out_device)
def get_next_point(gt_masks, pred_masks, method):
    """
    Sample the next corrective click for each mask.

    - method="uniform": one point drawn at random from the error regions
      (`sample_random_points_from_errors`)
    - method="center": the point deepest inside the largest error region
      (`sample_one_point_from_error_center`)

    Raises ValueError for any other `method`.
    """
    if method == "uniform":
        return sample_random_points_from_errors(gt_masks, pred_masks)
    if method == "center":
        return sample_one_point_from_error_center(gt_masks, pred_masks)
    raise ValueError(f"unknown sampling method {method}")
def select_closest_cond_frames(
    frame_idx, cond_frame_outputs, max_cond_frame_num, keep_first_cond_frame=False
):
    """
    Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
    that are temporally closest to the current frame at `frame_idx`. Here, we take
    - a) the closest conditioning frame before `frame_idx` (if any);
    - b) the closest conditioning frame at or after `frame_idx` (if any);
    - c) any other temporally closest conditioning frames until reaching a total
      of `max_cond_frame_num` conditioning frames.

    If `keep_first_cond_frame` is True, the earliest conditioning frame before
    `frame_idx` is also kept. If no frame precedes `frame_idx`, the latest one after it
    is kept instead, e.g. when tracking in reverse. With `max_cond_frame_num == 2`, this
    can yield 3 selected frames, since (a), (b) and the first frame may all differ.

    `max_cond_frame_num == -1` (or having no more frames than the limit) selects all
    frames; `cond_frame_outputs` itself is then returned as the selection.

    Outputs:
    - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
    - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
    """
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        selected_outputs = cond_frame_outputs
        unselected_outputs = {}
    else:
        assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
        selected_outputs = {}
        if keep_first_cond_frame:
            idx_first = min(
                (t for t in cond_frame_outputs if t < frame_idx), default=None
            )
            if idx_first is None:
                # Maybe we are tracking in reverse
                idx_first = max(
                    (t for t in cond_frame_outputs if t > frame_idx), default=None
                )
            if idx_first is not None:
                selected_outputs[idx_first] = cond_frame_outputs[idx_first]
        # the closest conditioning frame before `frame_idx` (if any)
        idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
        if idx_before is not None:
            selected_outputs[idx_before] = cond_frame_outputs[idx_before]
        # the closest conditioning frame at or after `frame_idx` (if any)
        idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
        if idx_after is not None:
            selected_outputs[idx_after] = cond_frame_outputs[idx_after]
        # add other temporally closest conditioning frames until reaching a total
        # of `max_cond_frame_num` conditioning frames. Clamp at 0: with
        # `keep_first_cond_frame` the picks above can already exceed the budget, and a
        # negative `num_remain` would make `[:num_remain]` keep all but the last k frames.
        num_remain = max(max_cond_frame_num - len(selected_outputs), 0)
        inds_remain = sorted(
            (t for t in cond_frame_outputs if t not in selected_outputs),
            key=lambda x: abs(x - frame_idx),
        )[:num_remain]
        selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
        unselected_outputs = {
            t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
        }
    return selected_outputs, unselected_outputs
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """
    Compute 1D sinusoidal positional embeddings, as in the original Transformer paper.

    Inputs:
    - pos_inds: tensor of positions (any shape); an embedding axis is appended at the end
    - dim: embedding size; the output has 2 * (dim // 2) channels (sine half, then
      cosine half), so an odd `dim` yields dim - 1 channels
    - temperature: base of the geometric progression of wavelengths
    Returns:
    - tensor of shape [*pos_inds.shape, 2 * (dim // 2)]
    """
    half_dim = dim // 2
    channel = torch.arange(half_dim, dtype=torch.float32, device=pos_inds.device)
    # consecutive channel pairs (2k, 2k + 1) share the same wavelength
    wavelengths = temperature ** (2 * (channel // 2) / half_dim)
    angles = pos_inds.unsqueeze(-1) / wavelengths
    return torch.cat((angles.sin(), angles.cos()), dim=-1)
def get_best_gt_match_from_multimasks(pred_multimasks, gt_masks, pred_scores=None):
    """
    Get the mask with the best match to GT masks (based on IoU) from pred_multimasks.
    Optionally, use `pred_scores` to break ties in case all IoUs are zeros.

    Inputs:
    - pred_multimasks: [B, M, H, W] mask logits (a pixel is foreground where > 0)
    - gt_masks: [B, 1, H, W] boolean GT masks (must broadcast against pred_multimasks)
    - pred_scores: optional [B, M] scores; for each batch element whose IoUs are all
      zero (e.g. empty GT mask), the mask with the highest score is chosen instead
    Returns:
    - [B, 1, H, W] logits of the selected mask (the input itself when M == 1)
    """
    assert pred_multimasks.ndim == 4 and gt_masks.ndim == 4
    if pred_multimasks.size(1) == 1:
        return pred_multimasks  # only a single mask channel, nothing to select
    pred_multimasks_binary = pred_multimasks > 0
    area_i = torch.sum(pred_multimasks_binary & gt_masks, dim=(2, 3)).float()
    area_u = torch.sum(pred_multimasks_binary | gt_masks, dim=(2, 3)).float()
    ious = area_i / torch.clamp(area_u, min=1.0)  # [B, M]
    # In case all IoUs are zeros (e.g. because the GT mask is empty), use pred_scores
    # to break ties and select the best mask. The check is per batch element: a
    # batch-wide `any` would let one element's nonzero IoU disable the fallback for
    # every other element.
    if pred_scores is not None:
        has_nonzero_ious = torch.any(ious > 0, dim=-1, keepdim=True).expand_as(ious)
        scores = torch.where(has_nonzero_ious, ious, pred_scores)
    else:
        scores = ious
    # Finally, take the best mask prediction (with the highest score)
    best_scores_inds = torch.argmax(scores, dim=-1)
    batch_inds = torch.arange(scores.size(0), device=scores.device)
    best_pred_mask = pred_multimasks[batch_inds, best_scores_inds].unsqueeze(1)
    return best_pred_mask
def fill_holes_in_mask_scores(mask, max_area, fill_holes=True, remove_sprinkles=True):
    """
    Post-process mask scores by removing small connected components (area <= `max_area`).

    - fill_holes: small background components ("holes") become foreground with score 0.1.
    - remove_sprinkles: small foreground components become background with score -0.1,
      but only those whose area is also at most half of that mask's total foreground
      area, so that tiny objects we want to track are not filtered out.

    A non-positive `max_area` disables all processing and returns `mask` unchanged.

    Connected components are found with the "cc_torch" package when available. You can
    install it via the following command (list the CUDA arches you need, e.g. 8.0 for A100):
    ```
    pip uninstall -y cc_torch; TORCH_CUDA_ARCH_LIST="8.0 9.0" pip install git+https://github.com/ronghanghu/cc_torch
    ```
    Otherwise, it will fallback to a slightly slower triton implementation, or skimage if the tensor is on cpu
    """
    if max_area <= 0:
        return mask  # nothing to fill in this case

    if fill_holes:
        is_bg = mask <= 0
        _, bg_component_areas = _get_connected_components_with_padding(is_bg)
        holes = is_bg & (bg_component_areas <= max_area)
        mask = torch.where(holes, 0.1, mask)

    if remove_sprinkles:
        is_fg = mask > 0
        # per-mask threshold: min(total foreground area // 2, max_area)
        fg_thresh = torch.sum(is_fg, dim=(2, 3), keepdim=True, dtype=torch.int32)
        fg_thresh.floor_divide_(2).clamp_(max=max_area)
        _, fg_component_areas = _get_connected_components_with_padding(is_fg)
        sprinkles = is_fg & (fg_component_areas <= fg_thresh)
        mask = torch.where(sprinkles, -0.1, mask)

    return mask
def _get_connected_components_with_padding(mask):
    """
    Run `connected_components` on a mask batch, padding odd spatial sizes to even ones.

    The mask is cast to uint8. If H or W is odd, one zero row/column is appended at the
    bottom/right so both become even (needed for cc_torch compatibility). Both outputs
    are then cropped back to the original H x W.

    Inputs:
    - mask: 4-D tensor [N, C, H, W] (presumably boolean -- TODO confirm with callers)
    Returns:
    - (labels, counts) as produced by `connected_components`, cropped to H x W
    """
    from sam3.perflib.connected_components import connected_components

    mask_u8 = mask.to(torch.uint8)
    _, _, height, width = mask_u8.shape
    extra_rows = height % 2
    extra_cols = width % 2
    if not extra_rows and not extra_cols:
        labels, counts = connected_components(mask_u8)
        return labels, counts

    # F.pad order for the last two dims: (left, right, top, bottom)
    padded = F.pad(mask_u8, (0, extra_cols, 0, extra_rows), mode="constant", value=0)
    labels, counts = connected_components(padded)
    return labels[:, :, :height, :width], counts[:, :, :height, :width]
|