| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- # Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
- from typing import Any, Dict, List, Optional, Tuple
- import numpy as np
- import torch
- from torchvision.ops.boxes import batched_nms, box_area # type: ignore
- from sam2.modeling.sam2_base import SAM2Base
- from sam2.sam2_image_predictor import SAM2ImagePredictor
- from sam2.utils.amg import (
- area_from_rle,
- batch_iterator,
- batched_mask_to_box,
- box_xyxy_to_xywh,
- build_all_layer_point_grids,
- calculate_stability_score,
- coco_encode_rle,
- generate_crop_boxes,
- is_box_near_crop_edge,
- mask_to_rle_pytorch,
- MaskData,
- remove_small_regions,
- rle_to_mask,
- uncrop_boxes_xyxy,
- uncrop_masks,
- uncrop_points,
- )
class SAM2AutomaticMaskGenerator:
    """
    Generates masks for an entire image by prompting a SAM 2 model with a
    grid of point prompts, then filtering low-quality and duplicate masks.
    Optionally repeats the process on overlapping image crops.
    """

    def __init__(
        self,
        model: SAM2Base,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.8,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        mask_threshold: float = 0.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
        use_m2m: bool = False,
        multimask_output: bool = True,
    ) -> None:
        """
        Using a SAM 2 model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM 2 with a HieraL backbone.

        Arguments:
          model (SAM2Base): The SAM 2 model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          mask_threshold (float): Threshold for binarizing the mask logits.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.
          use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
          multimask_output (bool): Whether to output multimask at each point of the grid.

        Raises:
          ValueError: if not exactly one of points_per_side / point_grids is
            given, or if output_mode is not a supported value.
          ImportError: if output_mode is 'coco_rle' and pycocotools cannot be
            imported.
        """
        # Explicit raises instead of `assert`: asserts are stripped under
        # `python -O`, which would silently accept invalid configurations.
        if (points_per_side is None) == (point_grids is None):
            raise ValueError(
                "Exactly one of points_per_side or point_grids must be provided."
            )
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        else:
            self.point_grids = point_grids

        valid_output_modes = ("binary_mask", "uncompressed_rle", "coco_rle")
        if output_mode not in valid_output_modes:
            raise ValueError(
                f"Unknown output_mode {output_mode}. "
                f"Expected one of {valid_output_modes}."
            )
        if output_mode == "coco_rle":
            # Fail fast at construction time rather than during generate().
            try:
                from pycocotools import mask as mask_utils  # type: ignore # noqa: F401
            except ImportError as e:
                raise ImportError(
                    "Please install pycocotools to use output_mode='coco_rle'."
                ) from e

        # Small-region cleanup is delegated to the predictor via its
        # hole/sprinkle area limits.
        self.predictor = SAM2ImagePredictor(
            model,
            max_hole_area=min_mask_region_area,
            max_sprinkle_area=min_mask_region_area,
        )
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.mask_threshold = mask_threshold
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode
        self.use_m2m = use_m2m
        self.multimask_output = multimask_output

    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
           list(dict(str, any)): A list over records for masks. Each record is
             a dict containing the following keys:
               segmentation (dict(str, any) or np.ndarray): The mask. If
                 output_mode='binary_mask', is an array of shape HW. Otherwise,
                 is a dictionary containing the RLE.
               bbox (list(float)): The box around the mask, in XYWH format.
               area (int): The area in pixels of the mask.
               predicted_iou (float): The model's own prediction of the mask's
                 quality. This is filtered by the pred_iou_thresh parameter.
               point_coords (list(list(float))): The point coordinates input
                 to the model to generate this mask.
               stability_score (float): A measure of the mask's quality. This
                 is filtered on using the stability_score_thresh parameter.
               crop_box (list(float)): The crop of the image used to generate
                 the mask, given in XYWH format.
        """
        # Generate masks
        mask_data = self._generate_masks(image)

        # Encode masks in the requested output format
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [
                coco_encode_rle(rle) for rle in mask_data["rles"]
            ]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write one record per surviving mask
        curr_anns = []
        for segmentation, rle, box, iou_pred, point, stability, crop_box in zip(
            mask_data["segmentations"],
            mask_data["rles"],
            mask_data["boxes"],
            mask_data["iou_preds"],
            mask_data["points"],
            mask_data["stability_score"],
            mask_data["crop_boxes"],
        ):
            curr_anns.append(
                {
                    "segmentation": segmentation,
                    "area": area_from_rle(rle),
                    "bbox": box_xyxy_to_xywh(box).tolist(),
                    "predicted_iou": iou_pred.item(),
                    "point_coords": [point.tolist()],
                    "stability_score": stability.item(),
                    "crop_box": box_xyxy_to_xywh(crop_box).tolist(),
                }
            )
        return curr_anns

    def _generate_masks(self, image: np.ndarray) -> MaskData:
        """
        Run mask generation over every crop produced by `generate_crop_boxes`,
        merge the per-crop results, de-duplicate across crops when more than
        one crop was processed, and convert the result to numpy.
        """
        orig_size = image.shape[:2]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops: score is the inverse crop area.
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)
        data.to_numpy()
        return data

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """
        Predict masks for one crop (x0, y0, x1, y1) of `image` using the point
        grid for `crop_layer_idx`, de-duplicate them with box NMS, and map
        boxes/points back to the original image frame.
        """
        # Crop the image and calculate embeddings
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.predictor.set_image(cropped_im)

        # Scale the normalized grid to crop pixels; the size is reversed so
        # the scale is (w, h), matching (x, y) point order.
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches
        data = MaskData()
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(
                points, cropped_im_size, crop_box, orig_size, normalize=True
            )
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_predictor()

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        # One crop box per mask. Build via repeat so an empty crop still yields
        # an integer tensor of shape (0, 4) rather than a float tensor of
        # shape (0,), keeping dtype/shape consistent for later cat/box_area.
        n_masks = len(data["rles"])
        data["crop_boxes"] = torch.as_tensor(crop_box).unsqueeze(0).repeat(n_masks, 1)

        return data

    def _process_batch(
        self,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
        normalize=False,
    ) -> MaskData:
        """
        Prompt the predictor with one batch of single-point prompts (crop pixel
        coordinates), optionally refine with m2m, filter by quality, drop
        masks touching the crop edge, and store the surviving masks as RLEs in
        the original image frame.
        """
        orig_h, orig_w = orig_size

        # Run model on this batch
        points = torch.as_tensor(points, device=self.predictor.device)
        in_points = self.predictor._transforms.transform_coords(
            points, normalize=normalize, orig_hw=im_size
        )
        in_labels = torch.ones(
            in_points.shape[0], dtype=torch.int, device=in_points.device
        )
        masks, iou_preds, low_res_masks = self.predictor._predict(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=self.multimask_output,
            return_logits=True,
        )

        # Serialize predictions and store in MaskData; each input point is
        # repeated once per mask it produced so rows stay aligned.
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            points=points.repeat_interleave(masks.shape[1], dim=0),
            low_res_masks=low_res_masks.flatten(0, 1),
        )
        del masks

        if self.use_m2m:
            # One step refinement using previous mask predictions
            in_points = self.predictor._transforms.transform_coords(
                data["points"], normalize=normalize, orig_hw=im_size
            )
            labels = torch.ones(
                in_points.shape[0], dtype=torch.int, device=in_points.device
            )
            masks, ious = self.refine_with_m2m(
                in_points, labels, data["low_res_masks"], self.points_per_batch
            )
            # Refinement runs with multimask_output=False; squeeze(1) drops
            # the singleton mask axis.
            data["masks"] = masks.squeeze(1)
            data["iou_preds"] = ious.squeeze(1)

        # Filter by predicted IoU and stability score
        self._filter_by_quality(data)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > self.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries (presumably only crop edges
        # that are not also the image border, given the full-image box passed)
        keep_mask = ~is_box_near_crop_edge(
            data["boxes"], crop_box, [0, 0, orig_w, orig_h]
        )
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    def _filter_by_quality(self, data: MaskData) -> None:
        """Filter `data` in place by predicted IoU, then compute and filter by stability score."""
        if self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)

        data["stability_score"] = calculate_stability_score(
            data["masks"], self.mask_threshold, self.stability_score_offset
        )
        if self.stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)

    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data

    def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
        """
        Re-run the predictor in batches of `points_per_batch`, feeding each
        point prompt together with its previous low-resolution mask as
        `mask_input`, with multimask_output=False.

        Returns:
          (masks, iou_preds): the per-batch outputs concatenated along dim 0.
        """
        new_masks = []
        new_iou_preds = []

        for cur_points, cur_point_labels, low_res_mask in batch_iterator(
            points_per_batch, points, point_labels, low_res_masks
        ):
            best_masks, best_iou_preds, _ = self.predictor._predict(
                cur_points[:, None, :],
                cur_point_labels[:, None],
                mask_input=low_res_mask[:, None, :],
                multimask_output=False,
                return_logits=True,
            )
            new_masks.append(best_masks)
            new_iou_preds.append(best_iou_preds)
        masks = torch.cat(new_masks, dim=0)
        return masks, torch.cat(new_iou_preds, dim=0)
|