# sam3_image_processor.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
from typing import Dict, List

import numpy as np
import PIL.Image
import torch

from sam3.model import box_ops
from sam3.model.data_misc import FindStage, interpolate
from torchvision.transforms import v2


class Sam3Processor:
    """Runs SAM 3 inference on a single image, prompted by text and/or boxes."""

    def __init__(
        self, model, resolution=1008, device="cuda", confidence_threshold=0.5
    ):
        self.model = model
        self.resolution = resolution
        self.device = device
        self.transform = v2.Compose(
            [
                v2.ToDtype(torch.uint8, scale=True),
                v2.Resize(size=(resolution, resolution)),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.confidence_threshold = confidence_threshold
        self.find_stage = FindStage(
            img_ids=torch.tensor([0], device=device, dtype=torch.long),
            text_ids=torch.tensor([0], device=device, dtype=torch.long),
            input_boxes=None,
            input_boxes_mask=None,
            input_boxes_label=None,
            input_points=None,
            input_points_mask=None,
        )

    @torch.inference_mode()
    def set_image(self, image, state=None):
        """Sets the image on which we want to do predictions."""
        if state is None:
            state = {}
        if isinstance(image, PIL.Image.Image):
            width, height = image.size
        elif isinstance(image, (torch.Tensor, np.ndarray)):
            # the last two dimensions are taken to be (H, W)
            height, width = image.shape[-2:]
        else:
            raise ValueError(
                "Image must be a PIL image, a torch tensor, or a numpy array"
            )
        image = v2.functional.to_image(image).to(self.device)
        image = self.transform(image).unsqueeze(0)
        state["original_height"] = height
        state["original_width"] = width
        state["backbone_out"] = self.model.backbone.forward_image(image)
        inst_interactivity_en = self.model.inst_interactive_predictor is not None
        if inst_interactivity_en and "sam2_backbone_out" in state["backbone_out"]:
            # project the two highest-resolution FPN levels with the SAM 2 mask
            # decoder's conv_s0 / conv_s1 so they are ready for interactive use
            sam2_backbone_out = state["backbone_out"]["sam2_backbone_out"]
            sam2_backbone_out["backbone_fpn"][0] = (
                self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s0(
                    sam2_backbone_out["backbone_fpn"][0]
                )
            )
            sam2_backbone_out["backbone_fpn"][1] = (
                self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s1(
                    sam2_backbone_out["backbone_fpn"][1]
                )
            )
        return state
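
    # A minimal usage sketch (hypothetical: assumes a loaded SAM 3 `model` and a
    # local "photo.jpg"):
    #
    #     processor = Sam3Processor(model)
    #     state = processor.set_image(PIL.Image.open("photo.jpg").convert("RGB"))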

    @torch.inference_mode()
    def set_image_batch(self, images: List[PIL.Image.Image], state=None):
        """Sets the image batch on which we want to do predictions."""
        if state is None:
            state = {}
        if not isinstance(images, list):
            raise ValueError("Images must be a list of PIL images")
        assert len(images) > 0, "Images list must not be empty"
        assert isinstance(images[0], PIL.Image.Image), (
            "Images must be a list of PIL images"
        )
        state["original_heights"] = [image.height for image in images]
        state["original_widths"] = [image.width for image in images]
        images = [
            self.transform(v2.functional.to_image(image).to(self.device))
            for image in images
        ]
        images = torch.stack(images, dim=0)
        state["backbone_out"] = self.model.backbone.forward_image(images)
        inst_interactivity_en = self.model.inst_interactive_predictor is not None
        if inst_interactivity_en and "sam2_backbone_out" in state["backbone_out"]:
            # same FPN projection as in set_image
            sam2_backbone_out = state["backbone_out"]["sam2_backbone_out"]
            sam2_backbone_out["backbone_fpn"][0] = (
                self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s0(
                    sam2_backbone_out["backbone_fpn"][0]
                )
            )
            sam2_backbone_out["backbone_fpn"][1] = (
                self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s1(
                    sam2_backbone_out["backbone_fpn"][1]
                )
            )
        return state
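
    # Hypothetical batch usage (PIL images only, as asserted above):
    #
    #     state = processor.set_image_batch([img_a, img_b])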

    @torch.inference_mode()
    def set_text_prompt(self, prompt: str, state: Dict):
        """Sets the text prompt and runs the inference."""
        if "backbone_out" not in state:
            raise ValueError("You must call set_image before set_text_prompt")
        text_outputs = self.model.backbone.forward_text([prompt], device=self.device)
        # this will erase the previous text prompt, if any
        state["backbone_out"].update(text_outputs)
        if "geometric_prompt" not in state:
            state["geometric_prompt"] = self.model._get_dummy_prompt()
        return self._forward_grounding(state)
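
    # Continuing the sketch above (hypothetical prompt):
    #
    #     state = processor.set_text_prompt("a red car", state)
    #     boxes, scores = state["boxes"], state["scores"]  # filtered detections
    #     masks = state["masks"]  # boolean masks at the original resolution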

    @torch.inference_mode()
    def add_geometric_prompt(self, box: List, label: bool, state: Dict):
        """Adds a box prompt and runs the inference.

        The image needs to be set, but not necessarily the text prompt.
        The box is assumed to be in [center_x, center_y, width, height] format,
        normalized to the [0, 1] range. The label is True for a positive box,
        False for a negative box. See the worked example after this method.
        """
        if "backbone_out" not in state:
            raise ValueError("You must call set_image before add_geometric_prompt")
        if "language_features" not in state["backbone_out"]:
            # No text prompt has been set yet. This is allowed, but we need to
            # set the text prompt to "visual" so that the model relies only on
            # the geometric prompt.
            dummy_text_outputs = self.model.backbone.forward_text(
                ["visual"], device=self.device
            )
            state["backbone_out"].update(dummy_text_outputs)
        if "geometric_prompt" not in state:
            state["geometric_prompt"] = self.model._get_dummy_prompt()
        # add a batch and a sequence dimension
        boxes = torch.tensor(box, device=self.device, dtype=torch.float32).view(1, 1, 4)
        labels = torch.tensor([label], device=self.device, dtype=torch.bool).view(1, 1)
        state["geometric_prompt"].append_boxes(boxes, labels)
        return self._forward_grounding(state)
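
    # Worked example with hypothetical numbers: the pixel-space box
    # (x0, y0, x1, y1) = (100, 50, 300, 250) on a 1000x500 image becomes,
    # in normalized cxcywh:
    #
    #     cx = (100 + 300) / 2 / 1000 = 0.2
    #     cy = (50 + 250) / 2 / 500  = 0.3
    #     w  = (300 - 100) / 1000    = 0.2
    #     h  = (250 - 50) / 500      = 0.4
    #     state = processor.add_geometric_prompt([0.2, 0.3, 0.2, 0.4], True, state)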

    def reset_all_prompts(self, state: Dict):
        """Removes all the prompts and results."""
        if "backbone_out" in state:
            backbone_keys_to_del = [
                "language_features",
                "language_mask",
                "language_embeds",
            ]
            for key in backbone_keys_to_del:
                if key in state["backbone_out"]:
                    del state["backbone_out"][key]
        keys_to_del = ["geometric_prompt", "boxes", "masks", "masks_logits", "scores"]
        for key in keys_to_del:
            if key in state:
                del state[key]

    @torch.inference_mode()
    def set_confidence_threshold(self, threshold: float, state=None):
        """Sets the confidence threshold used to filter predictions."""
        self.confidence_threshold = threshold
        if state is not None and "boxes" in state:
            # We need to filter the boxes again. In principle we could do this
            # more efficiently, since we would only need to rerun the heads, but
            # this is simpler and not too inefficient.
            return self._forward_grounding(state)
        return state
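
    # E.g. to tighten the filter after a first pass (hypothetical value):
    #
    #     state = processor.set_confidence_threshold(0.8, state)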

    @torch.inference_mode()
    def _forward_grounding(self, state: Dict):
        outputs = self.model.forward_grounding(
            backbone_out=state["backbone_out"],
            find_input=self.find_stage,
            geometric_prompt=state["geometric_prompt"],
            find_target=None,
        )
        out_bbox = outputs["pred_boxes"]
        out_logits = outputs["pred_logits"]
        out_masks = outputs["pred_masks"]
        # gate the per-query probabilities by the presence score
        out_probs = out_logits.sigmoid()
        presence_score = outputs["presence_logit_dec"].sigmoid().unsqueeze(1)
        out_probs = (out_probs * presence_score).squeeze(-1)
        # keep only the predictions above the confidence threshold
        keep = out_probs > self.confidence_threshold
        out_probs = out_probs[keep]
        out_masks = out_masks[keep]
        out_bbox = out_bbox[keep]
        # convert to [x0, y0, x1, y1] format and scale to the original image size
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        img_h = state["original_height"]
        img_w = state["original_width"]
        scale_fct = torch.tensor([img_w, img_h, img_w, img_h]).to(self.device)
        boxes = boxes * scale_fct[None, :]
        # upsample the mask logits to the original resolution, then binarize
        out_masks = interpolate(
            out_masks.unsqueeze(1),
            (img_h, img_w),
            mode="bilinear",
            align_corners=False,
        ).sigmoid()
        state["masks_logits"] = out_masks
        state["masks"] = out_masks > 0.5
        state["boxes"] = boxes
        state["scores"] = out_probs
        return state
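

# A minimal end-to-end sketch. Hypothetical: `build_sam3_model` stands in for
# whatever factory the package actually provides, and checkpoint loading is
# elided.
#
#     import PIL.Image
#
#     model = build_sam3_model(...)  # hypothetical factory
#     processor = Sam3Processor(model, confidence_threshold=0.5)
#
#     state = processor.set_image(PIL.Image.open("photo.jpg").convert("RGB"))
#     state = processor.set_text_prompt("a red car", state)
#     for box, score in zip(state["boxes"], state["scores"]):
#         print(box.tolist(), float(score))
#
#     # box prompts can refine or replace the text prompt
#     state = processor.add_geometric_prompt([0.5, 0.5, 0.2, 0.3], True, state)
#
#     processor.reset_all_prompts(state)  # keeps the cached image features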