# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

# pyre-unsafe

from typing import Dict, List

import numpy as np
import PIL.Image
import torch
from torchvision.transforms import v2

from sam3.model import box_ops
from sam3.model.data_misc import FindStage, interpolate

class Sam3Processor:
    """Runs promptable segmentation with a SAM3 model: set an image, add text
    and/or box prompts, and read back the resulting masks, boxes, and scores
    from the returned state dict."""

    def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5):
        self.model = model
        self.resolution = resolution
        self.device = device
        self.transform = v2.Compose(
            [
                v2.ToDtype(torch.uint8, scale=True),
                v2.Resize(size=(resolution, resolution)),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.confidence_threshold = confidence_threshold
        self.find_stage = FindStage(
            img_ids=torch.tensor([0], device=device, dtype=torch.long),
            text_ids=torch.tensor([0], device=device, dtype=torch.long),
            input_boxes=None,
            input_boxes_mask=None,
            input_boxes_label=None,
            input_points=None,
            input_points_mask=None,
        )

    @torch.inference_mode()
    def set_image(self, image, state=None):
        """Sets the image on which we want to run predictions."""
        if state is None:
            state = {}
        if isinstance(image, PIL.Image.Image):
            width, height = image.size
        elif isinstance(image, (torch.Tensor, np.ndarray)):
            height, width = image.shape[-2:]
        else:
            raise ValueError("Image must be a PIL image, a numpy array, or a tensor")
        image = v2.functional.to_image(image).to(self.device)
        image = self.transform(image).unsqueeze(0)
        state["original_height"] = height
        state["original_width"] = width
        state["backbone_out"] = self.model.backbone.forward_image(image)
        inst_interactivity_en = self.model.inst_interactive_predictor is not None
        if inst_interactivity_en and "sam2_backbone_out" in state["backbone_out"]:
            # Pre-project the two highest-resolution FPN levels with the interactive
            # predictor's mask decoder convolutions (conv_s0 / conv_s1).
            sam2_backbone_out = state["backbone_out"]["sam2_backbone_out"]
            sam2_backbone_out["backbone_fpn"][0] = (
                self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s0(
                    sam2_backbone_out["backbone_fpn"][0]
                )
            )
            sam2_backbone_out["backbone_fpn"][1] = (
                self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s1(
                    sam2_backbone_out["backbone_fpn"][1]
                )
            )
        return state

    @torch.inference_mode()
    def set_image_batch(self, images: List[PIL.Image.Image], state=None):
        """Sets the image batch on which we want to run predictions."""
        if state is None:
            state = {}
        if not isinstance(images, list):
            raise ValueError("Images must be a list of PIL images")
        assert len(images) > 0, "Images list must not be empty"
        assert isinstance(images[0], PIL.Image.Image), (
            "Images must be a list of PIL images"
        )
        state["original_heights"] = [image.height for image in images]
        state["original_widths"] = [image.width for image in images]
        images = [
            self.transform(v2.functional.to_image(image).to(self.device))
            for image in images
        ]
        images = torch.stack(images, dim=0)
        state["backbone_out"] = self.model.backbone.forward_image(images)
        inst_interactivity_en = self.model.inst_interactive_predictor is not None
        if inst_interactivity_en and "sam2_backbone_out" in state["backbone_out"]:
            # Same FPN pre-projection as in set_image.
            sam2_backbone_out = state["backbone_out"]["sam2_backbone_out"]
            sam2_backbone_out["backbone_fpn"][0] = (
                self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s0(
                    sam2_backbone_out["backbone_fpn"][0]
                )
            )
            sam2_backbone_out["backbone_fpn"][1] = (
                self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s1(
                    sam2_backbone_out["backbone_fpn"][1]
                )
            )
        return state

    @torch.inference_mode()
    def set_text_prompt(self, prompt: str, state: Dict):
        """Sets the text prompt and runs inference."""
        if "backbone_out" not in state:
            raise ValueError("You must call set_image before set_text_prompt")
        text_outputs = self.model.backbone.forward_text([prompt], device=self.device)
        # this will erase the previous text prompt, if any
        state["backbone_out"].update(text_outputs)
        if "geometric_prompt" not in state:
            state["geometric_prompt"] = self.model._get_dummy_prompt()
        return self._forward_grounding(state)

    @torch.inference_mode()
    def add_geometric_prompt(self, box: List, label: bool, state: Dict):
        """Adds a box prompt and runs inference.

        The image needs to be set, but not necessarily the text prompt.
        The box is assumed to be in [center_x, center_y, width, height] format,
        normalized to the [0, 1] range.
        The label is True for a positive box, False for a negative box.
        """
        if "backbone_out" not in state:
            raise ValueError("You must call set_image before add_geometric_prompt")
        if "language_features" not in state["backbone_out"]:
            # No text prompt has been set yet. This is allowed, but we need to set the
            # text prompt to "visual" so the model relies only on the geometric prompt.
            dummy_text_outputs = self.model.backbone.forward_text(
                ["visual"], device=self.device
            )
            state["backbone_out"].update(dummy_text_outputs)
        if "geometric_prompt" not in state:
            state["geometric_prompt"] = self.model._get_dummy_prompt()
        # add a batch dimension and a sequence dimension
        boxes = torch.tensor(box, device=self.device, dtype=torch.float32).view(1, 1, 4)
        labels = torch.tensor([label], device=self.device, dtype=torch.bool).view(1, 1)
        state["geometric_prompt"].append_boxes(boxes, labels)
        return self._forward_grounding(state)

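    # Illustrative example (not from the original source): add a positive box prompt
    # centered on the image and spanning half its width and height, assuming `state`
    # was produced by set_image (and optionally set_text_prompt) on this processor:
    #
    #   state = processor.add_geometric_prompt(
    #       box=[0.5, 0.5, 0.5, 0.5], label=True, state=state
    #   )
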
    def reset_all_prompts(self, state: Dict):
        """Removes all the prompts and results."""
        if "backbone_out" in state:
            backbone_keys_to_del = [
                "language_features",
                "language_mask",
                "language_embeds",
            ]
            for key in backbone_keys_to_del:
                if key in state["backbone_out"]:
                    del state["backbone_out"][key]
        keys_to_del = ["geometric_prompt", "boxes", "masks", "masks_logits", "scores"]
        for key in keys_to_del:
            if key in state:
                del state[key]

    @torch.inference_mode()
    def set_confidence_threshold(self, threshold: float, state=None):
        """Sets the confidence threshold for the masks."""
        self.confidence_threshold = threshold
        if state is not None and "boxes" in state:
            # We need to filter the boxes again. In principle we could do this more
            # efficiently, since we would only need to rerun the heads, but this is
            # simpler and not too inefficient.
            return self._forward_grounding(state)
        return state

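    # Illustrative example (not from the original source): lowering the threshold on
    # an existing state re-runs grounding and refreshes state["masks"], state["boxes"],
    # and state["scores"]:
    #
    #   state = processor.set_confidence_threshold(0.3, state)
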
    @torch.inference_mode()
    def _forward_grounding(self, state: Dict):
        outputs = self.model.forward_grounding(
            backbone_out=state["backbone_out"],
            find_input=self.find_stage,
            geometric_prompt=state["geometric_prompt"],
            find_target=None,
        )
        out_bbox = outputs["pred_boxes"]
        out_logits = outputs["pred_logits"]
        out_masks = outputs["pred_masks"]
        out_probs = out_logits.sigmoid()
        # Gate the per-query probabilities by the image-level presence score, then keep
        # only the detections above the confidence threshold.
        presence_score = outputs["presence_logit_dec"].sigmoid().unsqueeze(1)
        out_probs = (out_probs * presence_score).squeeze(-1)
        keep = out_probs > self.confidence_threshold
        out_probs = out_probs[keep]
        out_masks = out_masks[keep]
        out_bbox = out_bbox[keep]
        # convert to [x0, y0, x1, y1] format and rescale to the original image size
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        img_h = state["original_height"]
        img_w = state["original_width"]
        scale_fct = torch.tensor([img_w, img_h, img_w, img_h]).to(self.device)
        boxes = boxes * scale_fct[None, :]
        # Upsample mask logits to the original resolution and convert to probabilities.
        out_masks = interpolate(
            out_masks.unsqueeze(1),
            (img_h, img_w),
            mode="bilinear",
            align_corners=False,
        ).sigmoid()
        state["masks_logits"] = out_masks
        state["masks"] = out_masks > 0.5
        state["boxes"] = boxes
        state["scores"] = out_probs
        return state
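
# Minimal end-to-end usage sketch (not part of the original module). It assumes
# `model` is an already-constructed SAM3 model exposing the attributes used above
# and `image` is a PIL.Image; both names, and the prompt string, are placeholders.
#
#   processor = Sam3Processor(model, resolution=1008, device="cuda")
#   state = processor.set_image(image)
#   state = processor.set_text_prompt("a red car", state)
#   masks = state["masks"]      # (N, 1, H, W) boolean masks above the 0.5 cutoff
#   boxes = state["boxes"]      # (N, 4) boxes in [x0, y0, x1, y1] pixel coordinates
#   scores = state["scores"]    # (N,) per-detection confidence scores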