sam3_video_inference.py 75 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import logging
  4. from collections import defaultdict
  5. import numpy as np
  6. import torch
  7. import torch.distributed as dist
  8. import torch.nn.functional as F
  9. from sam3 import perflib
  10. from sam3.logger import get_logger
  11. from sam3.model.act_ckpt_utils import clone_output_wrapper
  12. from sam3.model.box_ops import box_xywh_to_cxcywh, box_xyxy_to_xywh
  13. from sam3.model.data_misc import BatchedDatapoint, convert_my_tensors, FindStage
  14. from sam3.model.geometry_encoders import Prompt
  15. from sam3.model.io_utils import IMAGE_EXTS, load_resource_as_video_frames
  16. from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
  17. from sam3.model.sam3_video_base import MaskletConfirmationStatus, Sam3VideoBase
  18. from sam3.model.utils.misc import copy_data_to_device
  19. from sam3.perflib.compile import compile_wrapper, shape_logging_wrapper
  20. from sam3.perflib.masks_ops import masks_to_boxes as perf_masks_to_boxes
  21. from torchvision.ops import masks_to_boxes
  22. from tqdm.auto import tqdm
  23. logger = get_logger(__name__)
  24. class Sam3VideoInference(Sam3VideoBase):
    # NOTE(review): presumably `text_ids` values selecting an entry of
    # `find_text_batch`, which is built as ["<text placeholder>", "visual"]
    # (slot 0 = text prompt, slot 1 = visual prompt) -- confirm against the callers.
    TEXT_ID_FOR_TEXT = 0
    TEXT_ID_FOR_VISUAL = 1
  27. def __init__(
  28. self,
  29. image_size=1008,
  30. image_mean=(0.5, 0.5, 0.5),
  31. image_std=(0.5, 0.5, 0.5),
  32. compile_model=False,
  33. **kwargs,
  34. ):
  35. """
  36. hotstart_delay: int, the delay (in #frames) before the model starts to yield output, 0 to disable hotstart delay.
  37. hotstart_unmatch_thresh: int, remove the object if it has this many unmatched frames within its hotstart_delay period.
  38. If `hotstart_delay` is set to 0, this parameter is ignored.
  39. hotstart_dup_thresh: int, remove the object if it has overlapped with another object this many frames within its hotstart_delay period.
  40. """
  41. super().__init__(**kwargs)
  42. self.image_size = image_size
  43. self.image_mean = image_mean
  44. self.image_std = image_std
  45. self.compile_model = compile_model
  46. @torch.inference_mode()
  47. def init_state(
  48. self,
  49. resource_path,
  50. offload_video_to_cpu=False,
  51. async_loading_frames=False,
  52. video_loader_type="cv2",
  53. ):
  54. """Initialize an inference state from `resource_path` (an image or a video)."""
  55. images, orig_height, orig_width = load_resource_as_video_frames(
  56. resource_path=resource_path,
  57. image_size=self.image_size,
  58. offload_video_to_cpu=offload_video_to_cpu,
  59. img_mean=self.image_mean,
  60. img_std=self.image_std,
  61. async_loading_frames=async_loading_frames,
  62. video_loader_type=video_loader_type,
  63. )
  64. inference_state = {}
  65. inference_state["image_size"] = self.image_size
  66. inference_state["num_frames"] = len(images)
  67. # the original video height and width, used for resizing final output scores
  68. inference_state["orig_height"] = orig_height
  69. inference_state["orig_width"] = orig_width
  70. # values that don't change across frames (so we only need to hold one copy of them)
  71. inference_state["constants"] = {}
  72. # inputs on each frame
  73. self._construct_initial_input_batch(inference_state, images)
  74. # initialize extra states
  75. inference_state["tracker_inference_states"] = []
  76. inference_state["tracker_metadata"] = {}
  77. inference_state["feature_cache"] = {}
  78. inference_state["cached_frame_outputs"] = {}
  79. inference_state["action_history"] = [] # for logging user actions
  80. inference_state["is_image_only"] = is_image_type(resource_path)
  81. return inference_state
  82. @torch.inference_mode()
  83. def reset_state(self, inference_state):
  84. """Revert `inference_state` to what it was right after initialization."""
  85. inference_state["input_batch"].find_text_batch[0] = "<text placeholder>"
  86. inference_state["text_prompt"] = None
  87. for t in range(inference_state["num_frames"]):
  88. inference_state["input_batch"].find_inputs[t].text_ids[...] = 0
  89. # constructing an output list in inference state (we start with an empty list)
  90. inference_state["previous_stages_out"][t] = None
  91. inference_state["per_frame_raw_point_input"][t] = None
  92. inference_state["per_frame_raw_box_input"][t] = None
  93. inference_state["per_frame_visual_prompt"][t] = None
  94. inference_state["per_frame_geometric_prompt"][t] = None
  95. inference_state["per_frame_cur_step"][t] = 0
  96. inference_state["visual_prompt_embed"] = None
  97. inference_state["visual_prompt_mask"] = None
  98. inference_state["tracker_inference_states"].clear()
  99. inference_state["tracker_metadata"].clear()
  100. inference_state["feature_cache"].clear()
  101. inference_state["cached_frame_outputs"].clear()
  102. inference_state["action_history"].clear() # for logging user actions
  103. def _construct_initial_input_batch(self, inference_state, images):
  104. """Construct an initial `BatchedDatapoint` instance as input."""
  105. # 1) img_batch
  106. num_frames = len(images)
  107. device = self.device
  108. # 2) find_text_batch
  109. # "<text placeholder>" will be replaced by the actual text prompt when adding prompts
  110. find_text_batch = ["<text placeholder>", "visual"]
  111. # 3) find_inputs
  112. input_box_embedding_dim = 258 # historical default
  113. input_points_embedding_dim = 257 # historical default
  114. stages = [
  115. FindStage(
  116. img_ids=[stage_id],
  117. text_ids=[0],
  118. input_boxes=[torch.zeros(input_box_embedding_dim)],
  119. input_boxes_mask=[torch.empty(0, dtype=torch.bool)],
  120. input_boxes_label=[torch.empty(0, dtype=torch.long)],
  121. input_points=[torch.empty(0, input_points_embedding_dim)],
  122. input_points_mask=[torch.empty(0)],
  123. object_ids=[],
  124. )
  125. for stage_id in range(num_frames)
  126. ]
  127. for i in range(len(stages)):
  128. stages[i] = convert_my_tensors(stages[i])
  129. # construct the final `BatchedDatapoint` and cast to GPU
  130. input_batch = BatchedDatapoint(
  131. img_batch=images,
  132. find_text_batch=find_text_batch,
  133. find_inputs=stages,
  134. find_targets=[None] * num_frames,
  135. find_metadatas=[None] * num_frames,
  136. )
  137. input_batch = copy_data_to_device(input_batch, device, non_blocking=True)
  138. inference_state["input_batch"] = input_batch
  139. # construct the placeholder interactive prompts and tracking queries
  140. bs = 1
  141. inference_state["constants"]["empty_geometric_prompt"] = Prompt(
  142. box_embeddings=torch.zeros(0, bs, 4, device=device),
  143. box_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool),
  144. box_labels=torch.zeros(0, bs, device=device, dtype=torch.long),
  145. point_embeddings=torch.zeros(0, bs, 2, device=device),
  146. point_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool),
  147. point_labels=torch.zeros(0, bs, device=device, dtype=torch.long),
  148. )
  149. # constructing an output list in inference state (we start with an empty list)
  150. inference_state["previous_stages_out"] = [None] * num_frames
  151. inference_state["text_prompt"] = None
  152. inference_state["per_frame_raw_point_input"] = [None] * num_frames
  153. inference_state["per_frame_raw_box_input"] = [None] * num_frames
  154. inference_state["per_frame_visual_prompt"] = [None] * num_frames
  155. inference_state["per_frame_geometric_prompt"] = [None] * num_frames
  156. inference_state["per_frame_cur_step"] = [0] * num_frames
  157. # placeholders for cached outputs
  158. # (note: currently, a single visual prompt embedding is shared for all frames)
  159. inference_state["visual_prompt_embed"] = None
  160. inference_state["visual_prompt_mask"] = None
  161. def _get_visual_prompt(self, inference_state, frame_idx, boxes_cxcywh, box_labels):
  162. """
  163. Handle the case of visual prompt. Currently, in the inference API we do not
  164. explicitly distinguish between initial box as visual prompt vs subsequent boxes
  165. or boxes after inference for refinement.
  166. """
  167. # If the frame hasn't had any inference results before (prompting or propagation),
  168. # we treat the first added box prompt as a visual prompt; otherwise, we treat
  169. # the first box just as a refinement prompt.
  170. is_new_visual_prompt = (
  171. inference_state["per_frame_visual_prompt"][frame_idx] is None
  172. and inference_state["previous_stages_out"][frame_idx] is None
  173. )
  174. if is_new_visual_prompt:
  175. if boxes_cxcywh.size(0) != 1:
  176. raise RuntimeError(
  177. "visual prompts (box as an initial prompt) should only have one box, "
  178. f"but got {boxes_cxcywh.shape=}"
  179. )
  180. if not box_labels.item():
  181. logging.warning("A negative box is added as a visual prompt.")
  182. # take the first box prompt as a visual prompt
  183. device = self.device
  184. new_visual_prompt = Prompt(
  185. box_embeddings=boxes_cxcywh[None, 0:1, :].to(device), # (seq, bs, 4)
  186. box_mask=None,
  187. box_labels=box_labels[None, 0:1].to(device), # (seq, bs)
  188. point_embeddings=None,
  189. point_mask=None,
  190. point_labels=None,
  191. )
  192. inference_state["per_frame_visual_prompt"][frame_idx] = new_visual_prompt
  193. else:
  194. new_visual_prompt = None
  195. # `boxes_cxcywh` and `box_labels` contains all the raw box inputs added so far
  196. # strip any visual prompt from the input boxes (for geometric prompt encoding)
  197. if inference_state["per_frame_visual_prompt"][frame_idx] is not None:
  198. boxes_cxcywh = boxes_cxcywh[1:]
  199. box_labels = box_labels[1:]
  200. return boxes_cxcywh, box_labels, new_visual_prompt
  201. def _get_processing_order(
  202. self, inference_state, start_frame_idx, max_frame_num_to_track, reverse
  203. ):
  204. num_frames = inference_state["num_frames"]
  205. previous_stages_out = inference_state["previous_stages_out"]
  206. if all(out is None for out in previous_stages_out) and start_frame_idx is None:
  207. raise RuntimeError(
  208. "No prompts are received on any frames. Please add prompt on at least one frame before propagation."
  209. )
  210. # set start index, end index, and processing order
  211. if start_frame_idx is None:
  212. # default: start from the earliest frame with input points
  213. start_frame_idx = min(
  214. t for t, out in enumerate(previous_stages_out) if out is not None
  215. )
  216. if max_frame_num_to_track is None:
  217. # default: track all the frames in the video
  218. max_frame_num_to_track = num_frames
  219. if reverse:
  220. end_frame_idx = start_frame_idx - max_frame_num_to_track
  221. end_frame_idx = max(end_frame_idx, 0)
  222. processing_order = range(start_frame_idx - 1, end_frame_idx - 1, -1)
  223. else:
  224. end_frame_idx = start_frame_idx + max_frame_num_to_track
  225. end_frame_idx = min(end_frame_idx, num_frames - 1)
  226. processing_order = range(start_frame_idx, end_frame_idx + 1)
  227. return processing_order, end_frame_idx
  228. @torch.inference_mode()
  229. def propagate_in_video(
  230. self,
  231. inference_state,
  232. start_frame_idx=None,
  233. max_frame_num_to_track=None,
  234. reverse=False,
  235. ):
  236. """
  237. Propagate the prompts to get grounding results for the entire video. This method
  238. is a generator and yields inference outputs for all frames in the range specified
  239. by `start_frame_idx`, `max_frame_num_to_track`, and `reverse`.
  240. """
  241. # compile the model (it's a no-op if the model is already compiled)
  242. # note that it's intentionally added to `self.propagate_in_video`, so that the first
  243. # `self.add_prompt` call will be done in eager mode to fill in the decoder buffers
  244. # such as positional encoding cache)
  245. self._compile_model()
  246. processing_order, end_frame_idx = self._get_processing_order(
  247. inference_state,
  248. start_frame_idx,
  249. max_frame_num_to_track,
  250. reverse=reverse,
  251. )
  252. # Store max_frame_num_to_track in feature_cache for downstream methods
  253. inference_state["feature_cache"]["tracking_bounds"] = {
  254. "max_frame_num_to_track": max_frame_num_to_track,
  255. "propagate_in_video_start_frame_idx": start_frame_idx,
  256. }
  257. hotstart_buffer = []
  258. hotstart_removed_obj_ids = set()
  259. # when deciding whether to output a masklet on `yield_frame_idx`, we check whether the object is confirmed
  260. # in a future frame (`unconfirmed_frame_delay` frames after the current frame). For example, if we require
  261. # an object to be detected in 3 consecutive frames to be confirmed, then we look 2 frames in the future --
  262. # e.g., we output an object on frame 4 only if it becomes confirmed on frame 6.
  263. unconfirmed_status_delay = self.masklet_confirmation_consecutive_det_thresh - 1
  264. unconfirmed_obj_ids_per_frame = {} # frame_idx -> hidden_obj_ids
  265. for frame_idx in tqdm(
  266. processing_order, desc="propagate_in_video", disable=self.rank > 0
  267. ):
  268. out = self._run_single_frame_inference(inference_state, frame_idx, reverse)
  269. if self.hotstart_delay > 0:
  270. # accumulate the outputs for the first `hotstart_delay` frames
  271. hotstart_buffer.append([frame_idx, out])
  272. # update the object IDs removed by hotstart so that we don't output them
  273. if self.rank == 0:
  274. hotstart_removed_obj_ids.update(out["removed_obj_ids"])
  275. unconfirmed_obj_ids = out.get("unconfirmed_obj_ids", None)
  276. if unconfirmed_obj_ids is not None:
  277. unconfirmed_obj_ids_per_frame[frame_idx] = unconfirmed_obj_ids
  278. if frame_idx == end_frame_idx:
  279. # we reached the end of propagation -- yield all frames in the buffer
  280. yield_list = hotstart_buffer
  281. hotstart_buffer = []
  282. elif len(hotstart_buffer) >= self.hotstart_delay:
  283. # we have enough frames -- yield and remove the first (oldest) frame from the buffer
  284. yield_list = hotstart_buffer[:1]
  285. hotstart_buffer = hotstart_buffer[1:]
  286. else:
  287. # not enough frames yet -- skip yielding
  288. yield_list = []
  289. else:
  290. yield_list = [(frame_idx, out)] # output the current frame
  291. for yield_frame_idx, yield_out in yield_list:
  292. # post-process the output and yield it
  293. if self.rank == 0:
  294. suppressed_obj_ids = yield_out["suppressed_obj_ids"]
  295. unconfirmed_status_frame_idx = (
  296. yield_frame_idx + unconfirmed_status_delay
  297. if not reverse
  298. else yield_frame_idx - unconfirmed_status_delay
  299. )
  300. # Clamp the frame index to stay within video bounds
  301. num_frames = inference_state["num_frames"]
  302. unconfirmed_status_frame_idx = max(
  303. 0, min(unconfirmed_status_frame_idx, num_frames - 1)
  304. )
  305. unconfirmed_obj_ids = unconfirmed_obj_ids_per_frame.get(
  306. unconfirmed_status_frame_idx, None
  307. )
  308. postprocessed_out = self._postprocess_output(
  309. inference_state,
  310. yield_out,
  311. hotstart_removed_obj_ids,
  312. suppressed_obj_ids,
  313. unconfirmed_obj_ids,
  314. )
  315. self._cache_frame_outputs(
  316. inference_state,
  317. yield_frame_idx,
  318. yield_out["obj_id_to_mask"],
  319. suppressed_obj_ids=suppressed_obj_ids,
  320. removed_obj_ids=hotstart_removed_obj_ids,
  321. unconfirmed_obj_ids=unconfirmed_obj_ids,
  322. )
  323. else:
  324. postprocessed_out = None # no output on other GPUs
  325. yield yield_frame_idx, postprocessed_out
  326. def _run_single_frame_inference(self, inference_state, frame_idx, reverse):
  327. """
  328. Perform inference on a single frame and get its inference results. This would
  329. also update `inference_state`.
  330. """
  331. # prepare inputs
  332. input_batch = inference_state["input_batch"]
  333. tracker_states_local = inference_state["tracker_inference_states"]
  334. has_text_prompt = inference_state["text_prompt"] is not None
  335. has_geometric_prompt = (
  336. inference_state["per_frame_geometric_prompt"][frame_idx] is not None
  337. )
  338. # run inference for the current frame
  339. (
  340. obj_id_to_mask,
  341. obj_id_to_score,
  342. tracker_states_local_new,
  343. tracker_metadata_new,
  344. frame_stats,
  345. _,
  346. ) = self._det_track_one_frame(
  347. frame_idx=frame_idx,
  348. num_frames=inference_state["num_frames"],
  349. reverse=reverse,
  350. input_batch=input_batch,
  351. geometric_prompt=(
  352. inference_state["constants"]["empty_geometric_prompt"]
  353. if not has_geometric_prompt
  354. else inference_state["per_frame_geometric_prompt"][frame_idx]
  355. ),
  356. tracker_states_local=tracker_states_local,
  357. tracker_metadata_prev=inference_state["tracker_metadata"],
  358. feature_cache=inference_state["feature_cache"],
  359. orig_vid_height=inference_state["orig_height"],
  360. orig_vid_width=inference_state["orig_width"],
  361. is_image_only=inference_state["is_image_only"],
  362. allow_new_detections=has_text_prompt or has_geometric_prompt,
  363. )
  364. # update inference state
  365. inference_state["tracker_inference_states"] = tracker_states_local_new
  366. inference_state["tracker_metadata"] = tracker_metadata_new
  367. # use a dummy string in "previous_stages_out" to indicate this frame has outputs
  368. inference_state["previous_stages_out"][frame_idx] = "_THIS_FRAME_HAS_OUTPUTS_"
  369. if self.rank == 0:
  370. self._cache_frame_outputs(inference_state, frame_idx, obj_id_to_mask)
  371. out = {
  372. "obj_id_to_mask": obj_id_to_mask,
  373. "obj_id_to_score": obj_id_to_score, # first frame detection score
  374. "obj_id_to_tracker_score": tracker_metadata_new[
  375. "obj_id_to_tracker_score_frame_wise"
  376. ][frame_idx],
  377. }
  378. # removed_obj_ids is only needed on rank 0 to handle hotstart delay buffer
  379. if self.rank == 0:
  380. rank0_metadata = tracker_metadata_new["rank0_metadata"]
  381. removed_obj_ids = rank0_metadata["removed_obj_ids"]
  382. out["removed_obj_ids"] = removed_obj_ids
  383. out["suppressed_obj_ids"] = rank0_metadata["suppressed_obj_ids"][frame_idx]
  384. out["frame_stats"] = frame_stats
  385. if self.masklet_confirmation_enable:
  386. status = rank0_metadata["masklet_confirmation"]["status"]
  387. is_unconfirmed = status == MaskletConfirmationStatus.UNCONFIRMED.value
  388. out["unconfirmed_obj_ids"] = tracker_metadata_new["obj_ids_all_gpu"][
  389. is_unconfirmed
  390. ].tolist()
  391. else:
  392. out["unconfirmed_obj_ids"] = []
  393. return out
  394. def _postprocess_output(
  395. self,
  396. inference_state,
  397. out,
  398. removed_obj_ids=None,
  399. suppressed_obj_ids=None,
  400. unconfirmed_obj_ids=None,
  401. ):
  402. obj_id_to_mask = out["obj_id_to_mask"] # low res masks
  403. curr_obj_ids = sorted(obj_id_to_mask.keys())
  404. H_video, W_video = inference_state["orig_height"], inference_state["orig_width"]
  405. if len(curr_obj_ids) == 0:
  406. out_obj_ids = torch.zeros(0, dtype=torch.int64)
  407. out_probs = torch.zeros(0, dtype=torch.float32)
  408. out_binary_masks = torch.zeros(0, H_video, W_video, dtype=torch.bool)
  409. out_boxes_xywh = torch.zeros(0, 4, dtype=torch.float32)
  410. else:
  411. out_obj_ids = torch.tensor(curr_obj_ids, dtype=torch.int64)
  412. out_probs = torch.tensor(
  413. [out["obj_id_to_score"][obj_id] for obj_id in curr_obj_ids]
  414. )
  415. out_tracker_probs = torch.tensor(
  416. [
  417. (
  418. out["obj_id_to_tracker_score"][obj_id]
  419. if obj_id in out["obj_id_to_tracker_score"]
  420. else 0.0
  421. )
  422. for obj_id in curr_obj_ids
  423. ]
  424. )
  425. out_binary_masks = torch.cat(
  426. [obj_id_to_mask[obj_id] for obj_id in curr_obj_ids], dim=0
  427. )
  428. assert out_binary_masks.dtype == torch.bool
  429. keep = out_binary_masks.any(dim=(1, 2)).cpu() # remove masks with 0 areas
  430. # hide outputs for those object IDs in `obj_ids_to_hide`
  431. obj_ids_to_hide = []
  432. if suppressed_obj_ids is not None:
  433. obj_ids_to_hide.extend(suppressed_obj_ids)
  434. if removed_obj_ids is not None:
  435. obj_ids_to_hide.extend(removed_obj_ids)
  436. if unconfirmed_obj_ids is not None:
  437. obj_ids_to_hide.extend(unconfirmed_obj_ids)
  438. if len(obj_ids_to_hide) > 0:
  439. obj_ids_to_hide_t = torch.tensor(obj_ids_to_hide, dtype=torch.int64)
  440. keep &= ~torch.isin(out_obj_ids, obj_ids_to_hide_t)
  441. # slice those valid entries from the original outputs
  442. keep_idx = torch.nonzero(keep, as_tuple=True)[0]
  443. keep_idx_gpu = keep_idx.pin_memory().to(
  444. device=out_binary_masks.device, non_blocking=True
  445. )
  446. out_obj_ids = torch.index_select(out_obj_ids, 0, keep_idx)
  447. out_probs = torch.index_select(out_probs, 0, keep_idx)
  448. out_tracker_probs = torch.index_select(out_tracker_probs, 0, keep_idx)
  449. out_binary_masks = torch.index_select(out_binary_masks, 0, keep_idx_gpu)
  450. if perflib.is_enabled:
  451. out_boxes_xyxy = perf_masks_to_boxes(
  452. out_binary_masks, out_obj_ids.tolist()
  453. )
  454. else:
  455. out_boxes_xyxy = masks_to_boxes(out_binary_masks)
  456. out_boxes_xywh = box_xyxy_to_xywh(out_boxes_xyxy) # convert to xywh format
  457. # normalize boxes
  458. out_boxes_xywh[..., 0] /= W_video
  459. out_boxes_xywh[..., 1] /= H_video
  460. out_boxes_xywh[..., 2] /= W_video
  461. out_boxes_xywh[..., 3] /= H_video
  462. # apply non-overlapping constraints on the existing masklets
  463. if out_binary_masks.shape[0] > 1:
  464. assert len(out_binary_masks) == len(out_tracker_probs)
  465. out_binary_masks = (
  466. self.tracker._apply_object_wise_non_overlapping_constraints(
  467. out_binary_masks.unsqueeze(1),
  468. out_tracker_probs.unsqueeze(1).to(out_binary_masks.device),
  469. background_value=0,
  470. ).squeeze(1)
  471. ) > 0
  472. outputs = {
  473. "out_obj_ids": out_obj_ids.cpu().numpy(),
  474. "out_probs": out_probs.cpu().numpy(),
  475. "out_boxes_xywh": out_boxes_xywh.cpu().numpy(),
  476. "out_binary_masks": out_binary_masks.cpu().numpy(),
  477. "frame_stats": out.get("frame_stats", None),
  478. }
  479. return outputs
  480. def _cache_frame_outputs(
  481. self,
  482. inference_state,
  483. frame_idx,
  484. obj_id_to_mask,
  485. suppressed_obj_ids=None,
  486. removed_obj_ids=None,
  487. unconfirmed_obj_ids=None,
  488. ):
  489. # Filter out suppressed, removed, and unconfirmed objects from the cache
  490. filtered_obj_id_to_mask = obj_id_to_mask.copy()
  491. objects_to_exclude = set()
  492. if suppressed_obj_ids is not None:
  493. objects_to_exclude.update(suppressed_obj_ids)
  494. if removed_obj_ids is not None:
  495. objects_to_exclude.update(removed_obj_ids)
  496. if unconfirmed_obj_ids is not None:
  497. objects_to_exclude.update(unconfirmed_obj_ids)
  498. if objects_to_exclude:
  499. for obj_id in objects_to_exclude:
  500. if obj_id in filtered_obj_id_to_mask:
  501. del filtered_obj_id_to_mask[obj_id]
  502. inference_state["cached_frame_outputs"][frame_idx] = filtered_obj_id_to_mask
  503. def _build_tracker_output(
  504. self, inference_state, frame_idx, refined_obj_id_to_mask=None
  505. ):
  506. assert (
  507. "cached_frame_outputs" in inference_state
  508. and frame_idx in inference_state["cached_frame_outputs"]
  509. ), (
  510. "No cached outputs found. Ensure normal propagation has run first to populate the cache."
  511. )
  512. cached_outputs = inference_state["cached_frame_outputs"][frame_idx]
  513. obj_id_to_mask = cached_outputs.copy()
  514. # Update with refined masks if provided
  515. if refined_obj_id_to_mask is not None:
  516. for obj_id, refined_mask in refined_obj_id_to_mask.items():
  517. assert refined_mask is not None, (
  518. f"Refined mask data must be provided for obj_id {obj_id}"
  519. )
  520. obj_id_to_mask[obj_id] = refined_mask
  521. return obj_id_to_mask
  522. def _compile_model(self):
  523. """Compile the SAM model with torch.compile for speedup."""
  524. is_compiled = getattr(self, "_model_is_compiled", False)
  525. if is_compiled or not self.compile_model:
  526. return
  527. import torch._dynamo
  528. # a larger cache size to hold varying number of shapes for torch.compile
  529. # see https://github.com/pytorch/pytorch/blob/v2.5.1/torch/_dynamo/config.py#L42-L49
  530. torch._dynamo.config.cache_size_limit = 128
  531. torch._dynamo.config.accumulated_cache_size_limit = 2048
  532. torch._dynamo.config.capture_scalar_outputs = True
  533. torch._dynamo.config.suppress_errors = True
  534. # Compile module components
  535. # skip compilation of `_encode_prompt` since it sometimes tiggger SymInt errors
  536. # self._encode_prompt = clone_output_wrapper(
  537. # torch.compile(self._encode_prompt, fullgraph=True, mode="max-autotune")
  538. # )
  539. ## Compile SAM3 model components
  540. self.detector.backbone.vision_backbone.forward = clone_output_wrapper(
  541. torch.compile(
  542. self.detector.backbone.vision_backbone.forward,
  543. fullgraph=True,
  544. mode="max-autotune",
  545. )
  546. )
  547. self.detector.transformer.encoder.forward = clone_output_wrapper(
  548. torch.compile(
  549. self.detector.transformer.encoder.forward,
  550. fullgraph=True,
  551. mode="max-autotune",
  552. )
  553. )
  554. self.detector.transformer.decoder.forward = clone_output_wrapper(
  555. torch.compile(
  556. self.detector.transformer.decoder.forward,
  557. fullgraph=True,
  558. mode="max-autotune",
  559. dynamic=False,
  560. )
  561. )
  562. self.detector.segmentation_head.forward = clone_output_wrapper(
  563. torch.compile(
  564. self.detector.segmentation_head.forward,
  565. fullgraph=True,
  566. mode="max-autotune",
  567. )
  568. )
  569. ## Compile Tracker model components
  570. self.tracker.maskmem_backbone.forward = compile_wrapper(
  571. self.tracker.maskmem_backbone.forward,
  572. mode="max-autotune",
  573. fullgraph=True,
  574. dynamic=False,
  575. )
  576. self.tracker.transformer.encoder.forward = shape_logging_wrapper(
  577. compile_wrapper(
  578. self.tracker.transformer.encoder.forward,
  579. mode="max-autotune-no-cudagraphs",
  580. fullgraph=True,
  581. dynamic=True,
  582. ),
  583. keep_kwargs=["src", "src_pos", "prompt", "prompt_pos"],
  584. )
  585. self.tracker.sam_mask_decoder.forward = compile_wrapper(
  586. self.tracker.sam_mask_decoder.forward,
  587. mode="max-autotune",
  588. fullgraph=True,
  589. dynamic=False, # Accuracy regression on True
  590. )
  591. self._model_is_compiled = True
  592. def _warm_up_vg_propagation(self, inference_state, start_frame_idx=0):
  593. # use different tracking score thresholds for each round to simulate different number of output objects
  594. num_objects_list = range(self.num_obj_for_compile + 1)
  595. new_det_score_thresh_list = [0.3, 0.5, 0.7]
  596. num_rounds = len(new_det_score_thresh_list)
  597. orig_new_det_thresh = self.new_det_thresh
  598. for i, thresh in enumerate(new_det_score_thresh_list):
  599. self.new_det_thresh = thresh
  600. for num_objects in num_objects_list:
  601. logger.info(f"{i + 1}/{num_rounds} warming up model compilation")
  602. self.add_prompt(
  603. inference_state, frame_idx=start_frame_idx, text_str="cat"
  604. )
  605. logger.info(
  606. f"{i + 1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
  607. )
  608. inference_state = self.add_fake_objects_to_inference_state(
  609. inference_state, num_objects, frame_idx=start_frame_idx
  610. )
  611. inference_state["tracker_metadata"]["rank0_metadata"].update(
  612. {
  613. "masklet_confirmation": {
  614. "status": np.zeros(num_objects, dtype=np.int64),
  615. "consecutive_det_num": np.zeros(
  616. num_objects, dtype=np.int64
  617. ),
  618. }
  619. }
  620. )
  621. for _ in self.propagate_in_video(
  622. inference_state, start_frame_idx, reverse=False
  623. ):
  624. pass
  625. for _ in self.propagate_in_video(
  626. inference_state, start_frame_idx, reverse=True
  627. ):
  628. pass
  629. self.reset_state(inference_state)
  630. logger.info(
  631. f"{i + 1}/{num_rounds} warming up model compilation -- completed round {i + 1} out of {num_rounds}"
  632. )
  633. # Warm up Tracker memory encoder with varying input shapes
  634. num_iters = 3
  635. feat_size = self.tracker.sam_image_embedding_size**2 # 72 * 72 = 5184
  636. hidden_dim = self.tracker.hidden_dim # 256
  637. mem_dim = self.tracker.mem_dim # 64
  638. for _ in tqdm(range(num_iters)):
  639. for b in range(1, self.num_obj_for_compile + 1):
  640. for i in range(
  641. 1,
  642. self.tracker.max_cond_frames_in_attn + self.tracker.num_maskmem,
  643. ):
  644. for j in range(
  645. self.tracker.max_cond_frames_in_attn
  646. + self.tracker.max_obj_ptrs_in_encoder
  647. ):
  648. num_obj_ptr_tokens = (hidden_dim // mem_dim) * j
  649. src = torch.randn(feat_size, b, hidden_dim, device=self.device)
  650. src_pos = torch.randn(
  651. feat_size, b, hidden_dim, device=self.device
  652. )
  653. prompt = torch.randn(
  654. feat_size * i + num_obj_ptr_tokens,
  655. b,
  656. mem_dim,
  657. device=self.device,
  658. )
  659. prompt_pos = torch.randn(
  660. feat_size * i + num_obj_ptr_tokens,
  661. b,
  662. mem_dim,
  663. device=self.device,
  664. )
  665. self.tracker.transformer.encoder.forward(
  666. src=src,
  667. src_pos=src_pos,
  668. prompt=prompt,
  669. prompt_pos=prompt_pos,
  670. num_obj_ptr_tokens=num_obj_ptr_tokens,
  671. )
  672. self.new_det_thresh = orig_new_det_thresh
  673. return inference_state
  674. def add_fake_objects_to_inference_state(
  675. self, inference_state, num_objects, frame_idx
  676. ):
  677. new_det_obj_ids_local = np.arange(num_objects)
  678. high_res_H, high_res_W = (
  679. self.tracker.maskmem_backbone.mask_downsampler.interpol_size
  680. )
  681. new_det_masks = torch.ones(
  682. len(new_det_obj_ids_local), high_res_H, high_res_W
  683. ).to(self.device)
  684. inference_state["tracker_inference_states"] = self._tracker_add_new_objects(
  685. frame_idx=frame_idx,
  686. num_frames=inference_state["num_frames"],
  687. new_obj_ids=new_det_obj_ids_local,
  688. new_obj_masks=new_det_masks,
  689. tracker_states_local=inference_state["tracker_inference_states"],
  690. orig_vid_height=inference_state["orig_height"],
  691. orig_vid_width=inference_state["orig_width"],
  692. feature_cache=inference_state["feature_cache"],
  693. )
  694. # Synthesize obj_id_to_mask data for cached_frame_outputs to support _build_tracker_output during warmup
  695. obj_id_to_mask = {}
  696. if num_objects > 0:
  697. H_video = inference_state["orig_height"]
  698. W_video = inference_state["orig_width"]
  699. video_res_masks = F.interpolate(
  700. new_det_masks.unsqueeze(1), # Add channel dimension for interpolation
  701. size=(H_video, W_video),
  702. mode="bilinear",
  703. align_corners=False,
  704. ) # (num_objects, 1, H_video, W_video)
  705. for i, obj_id in enumerate(new_det_obj_ids_local):
  706. obj_id_to_mask[obj_id] = (video_res_masks[i] > 0.0).to(torch.bool)
  707. if self.rank == 0:
  708. for fidx in range(inference_state["num_frames"]):
  709. self._cache_frame_outputs(inference_state, fidx, obj_id_to_mask)
  710. inference_state["tracker_metadata"].update(
  711. {
  712. "obj_ids_per_gpu": [np.arange(num_objects)],
  713. "obj_ids_all_gpu": np.arange(num_objects), # Same as 1 GPU
  714. "num_obj_per_gpu": [num_objects],
  715. "obj_id_to_score": {i: 1.0 for i in range(num_objects)},
  716. "max_obj_id": num_objects,
  717. "rank0_metadata": {
  718. "masklet_confirmation": {
  719. "status": np.zeros(num_objects, dtype=np.int64),
  720. "consecutive_det_num": np.zeros(num_objects, dtype=np.int64),
  721. },
  722. "removed_obj_ids": set(),
  723. "suppressed_obj_ids": defaultdict(set),
  724. },
  725. }
  726. )
  727. return inference_state
  728. @torch.inference_mode()
  729. @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
  730. def warm_up_compilation(self):
  731. """
  732. Warm up the model by running a dummy inference to compile the model. This is
  733. useful to avoid the compilation overhead in the first inference call.
  734. """
  735. if not self.compile_model:
  736. return
  737. self._warm_up_complete = False
  738. if self.device.type != "cuda":
  739. raise RuntimeError(
  740. f"The model must be on CUDA for warm-up compilation, got {self.device=}."
  741. )
  742. # temporally set to single GPU temporarily for warm-up compilation
  743. orig_rank = self.rank
  744. orig_world_size = self.world_size
  745. self.rank = self.detector.rank = 0
  746. self.world_size = self.detector.world_size = 1
  747. orig_recondition_every_nth_frame = self.recondition_every_nth_frame
  748. # self.recondition_every_nth_frame = 2
  749. # Get a random video
  750. inference_state = self.init_state(resource_path="<load-dummy-video-30>")
  751. start_frame_idx = 0
  752. # Run basic propagation warm-up
  753. inference_state = self._warm_up_vg_propagation(inference_state, start_frame_idx)
  754. logger.info("Warm-up compilation completed.")
  755. # revert to the original GPU and rank
  756. self.rank = self.detector.rank = orig_rank
  757. self.world_size = self.detector.world_size = orig_world_size
  758. self.recondition_every_nth_frame = orig_recondition_every_nth_frame
  759. self._warm_up_complete = True
  760. self.tracker.transformer.encoder.forward.set_logging(True)
  761. @torch.inference_mode()
  762. def add_prompt(
  763. self,
  764. inference_state,
  765. frame_idx,
  766. text_str=None,
  767. boxes_xywh=None,
  768. box_labels=None,
  769. ):
  770. """
  771. Add text, point or box prompts on a single frame. This method returns the inference
  772. outputs only on the prompted frame.
  773. Note that text prompts are NOT associated with a particular frame (i.e. they apply
  774. to all frames). However, we only run inference on the frame specified in `frame_idx`.
  775. """
  776. logger.debug("Running add_prompt on frame %d", frame_idx)
  777. num_frames = inference_state["num_frames"]
  778. assert text_str is not None or boxes_xywh is not None, (
  779. "at least one type of prompt (text, boxes) must be provided"
  780. )
  781. assert 0 <= frame_idx < num_frames, (
  782. f"{frame_idx=} is out of range for a total of {num_frames} frames"
  783. )
  784. # since it's a semantic prompt, we start over
  785. self.reset_state(inference_state)
  786. # 1) add text prompt
  787. if text_str is not None and text_str != "visual":
  788. inference_state["text_prompt"] = text_str
  789. inference_state["input_batch"].find_text_batch[0] = text_str
  790. text_id = self.TEXT_ID_FOR_TEXT
  791. else:
  792. inference_state["text_prompt"] = None
  793. inference_state["input_batch"].find_text_batch[0] = "<text placeholder>"
  794. text_id = self.TEXT_ID_FOR_VISUAL
  795. for t in range(inference_state["num_frames"]):
  796. inference_state["input_batch"].find_inputs[t].text_ids[...] = text_id
  797. # 2) handle box prompt
  798. assert (boxes_xywh is not None) == (box_labels is not None)
  799. if boxes_xywh is not None:
  800. boxes_xywh = torch.as_tensor(boxes_xywh, dtype=torch.float32)
  801. box_labels = torch.as_tensor(box_labels, dtype=torch.long)
  802. # input boxes are expected to be [xmin, ymin, width, height] format
  803. # in normalized coordinates of range 0~1, similar to FA
  804. assert boxes_xywh.dim() == 2
  805. assert boxes_xywh.size(0) > 0 and boxes_xywh.size(-1) == 4
  806. assert box_labels.dim() == 1 and box_labels.size(0) == boxes_xywh.size(0)
  807. boxes_cxcywh = box_xywh_to_cxcywh(boxes_xywh)
  808. assert (boxes_xywh >= 0).all().item() and (boxes_xywh <= 1).all().item()
  809. assert (boxes_cxcywh >= 0).all().item() and (boxes_cxcywh <= 1).all().item()
  810. new_box_input = boxes_cxcywh, box_labels
  811. inference_state["per_frame_raw_box_input"][frame_idx] = new_box_input
  812. # handle the case of visual prompt (also added as an input box from the UI)
  813. boxes_cxcywh, box_labels, geometric_prompt = self._get_visual_prompt(
  814. inference_state, frame_idx, boxes_cxcywh, box_labels
  815. )
  816. inference_state["per_frame_geometric_prompt"][frame_idx] = geometric_prompt
  817. out = self._run_single_frame_inference(
  818. inference_state, frame_idx, reverse=False
  819. )
  820. return frame_idx, self._postprocess_output(inference_state, out)
  821. @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
  822. def forward(self, input: BatchedDatapoint, is_inference: bool = False):
  823. """This method is only used for benchmark eval (not used in the demo)."""
  824. # set the model to single GPU for benchmark evaluation (to be compatible with trainer)
  825. orig_rank = self.rank
  826. orig_world_size = self.world_size
  827. self.rank = self.detector.rank = 0
  828. self.world_size = self.detector.world_size = 1
  829. # get data
  830. text_prompt_ids = input.find_metadatas[0].original_category_id
  831. text_prompt_list = input.find_text_batch
  832. # loop over txt prompts
  833. tracking_res = defaultdict(dict) # frame_idx --> {obj_id: mask}
  834. scores_labels = defaultdict(tuple) # obj_id --> (score, text_prompt_id)
  835. inference_state = self.init_state(resource_path=input.raw_images)
  836. for prompt_id, prompt in zip(text_prompt_ids, text_prompt_list):
  837. self.add_prompt(inference_state, frame_idx=0, text_str=prompt)
  838. start_obj_id = max(scores_labels.keys(), default=-1) + 1 # prev max + 1
  839. # propagate the prompts
  840. obj_ids_this_prompt = set()
  841. for frame_idx, out in self.propagate_in_video(
  842. inference_state,
  843. start_frame_idx=0,
  844. max_frame_num_to_track=inference_state["num_frames"],
  845. reverse=False,
  846. ):
  847. current_frame_res = tracking_res[frame_idx]
  848. for obj_id, mask in zip(out["out_obj_ids"], out["out_binary_masks"]):
  849. mask_tensor = torch.tensor(mask[None], dtype=torch.bool)
  850. current_frame_res[obj_id + start_obj_id] = mask_tensor
  851. obj_ids_this_prompt.update(current_frame_res.keys())
  852. obj_id_to_score = inference_state["tracker_metadata"]["obj_id_to_score"]
  853. for obj_id, score in obj_id_to_score.items():
  854. if obj_id + start_obj_id in obj_ids_this_prompt:
  855. score_tensor = torch.tensor(score, dtype=torch.float32)
  856. scores_labels[obj_id + start_obj_id] = (score_tensor, prompt_id)
  857. self.reset_state(inference_state)
  858. video_id = input.find_metadatas[0].original_image_id[0].cpu().item()
  859. preds = self.prep_for_evaluator(input.raw_images, tracking_res, scores_labels)
  860. # revert the model to the original GPU and rank
  861. self.rank = self.detector.rank = orig_rank
  862. self.world_size = self.detector.world_size = orig_world_size
  863. return {video_id: preds}
  864. def back_convert(self, targets):
  865. # Needed for retraining compatibility with trainer
  866. return targets
  867. class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
  868. def __init__(
  869. self,
  870. use_prev_mem_frame=False,
  871. use_stateless_refinement=False,
  872. refinement_detector_cond_frame_removal_window=16,
  873. **kwargs,
  874. ):
  875. """
  876. use_prev_mem_frame: bool, whether to condition on previous memory frames for adding points
  877. use_stateless_refinement: bool, whether to enable stateless refinement behavior
  878. refinement_detector_cond_frame_removal_window: int, we remove a detector conditioning frame if it
  879. is within this many frames of a user refined frame. Set to a large value (e.g. 10000) to
  880. always remove detector conditioning frames if there is any user refinement in the video.
  881. """
  882. super().__init__(**kwargs)
  883. self.use_prev_mem_frame = use_prev_mem_frame
  884. self.use_stateless_refinement = use_stateless_refinement
  885. self.refinement_detector_cond_frame_removal_window = (
  886. refinement_detector_cond_frame_removal_window
  887. )
  888. def _init_new_tracker_state(self, inference_state):
  889. return self.tracker.init_state(
  890. cached_features=inference_state["feature_cache"],
  891. video_height=inference_state["orig_height"],
  892. video_width=inference_state["orig_width"],
  893. num_frames=inference_state["num_frames"],
  894. )
  895. @torch.inference_mode()
  896. def propagate_in_video(
  897. self,
  898. inference_state,
  899. start_frame_idx=None,
  900. max_frame_num_to_track=None,
  901. reverse=False,
  902. ):
  903. # step 1: check which type of propagation to run, should be the same for all GPUs.
  904. propagation_type, obj_ids = self.parse_action_history_for_propagation(
  905. inference_state
  906. )
  907. self.add_action_history(
  908. inference_state,
  909. action_type=propagation_type,
  910. obj_ids=obj_ids,
  911. frame_idx=start_frame_idx,
  912. )
  913. # step 2: run full VG propagation
  914. if propagation_type == "propagation_full":
  915. logger.debug(f"Running full VG propagation (reverse={reverse}).")
  916. yield from super().propagate_in_video(
  917. inference_state,
  918. start_frame_idx=start_frame_idx,
  919. max_frame_num_to_track=max_frame_num_to_track,
  920. reverse=reverse,
  921. )
  922. return
  923. # step 3: run Tracker partial propagation or direct fetch existing predictions
  924. assert propagation_type in ["propagation_partial", "propagation_fetch"]
  925. logger.debug(
  926. f"Running Tracker propagation for objects {obj_ids} and merging it with existing VG predictions (reverse={reverse})."
  927. if propagation_type == "propagation_partial"
  928. else f"Fetching existing VG predictions without running any propagation (reverse={reverse})."
  929. )
  930. processing_order, _ = self._get_processing_order(
  931. inference_state,
  932. start_frame_idx=start_frame_idx,
  933. max_frame_num_to_track=max_frame_num_to_track,
  934. reverse=reverse,
  935. )
  936. tracker_metadata = inference_state["tracker_metadata"]
  937. # if fetch just return from output
  938. if propagation_type == "propagation_fetch":
  939. for frame_idx in tqdm(processing_order):
  940. if self.rank == 0:
  941. obj_id_to_mask = inference_state["cached_frame_outputs"].get(
  942. frame_idx, {}
  943. )
  944. # post processing - remove suppressed obj_ids
  945. obj_id_to_score = tracker_metadata["obj_id_to_score"]
  946. suppressed_obj_ids = tracker_metadata["rank0_metadata"][
  947. "suppressed_obj_ids"
  948. ][frame_idx]
  949. obj_id_to_tracker_score = tracker_metadata[
  950. "obj_id_to_tracker_score_frame_wise"
  951. ][frame_idx]
  952. out = {
  953. "obj_id_to_mask": obj_id_to_mask,
  954. "obj_id_to_score": obj_id_to_score,
  955. "obj_id_to_tracker_score": obj_id_to_tracker_score,
  956. }
  957. yield (
  958. frame_idx,
  959. self._postprocess_output(
  960. inference_state, out, suppressed_obj_ids=suppressed_obj_ids
  961. ),
  962. )
  963. else:
  964. yield frame_idx, None
  965. return
  966. # get Tracker inference states containing selected obj_ids
  967. if propagation_type == "propagation_partial":
  968. # can be empty for GPUs where objects are not in their inference states
  969. tracker_states_local = self._get_tracker_inference_states_by_obj_ids(
  970. inference_state, obj_ids
  971. )
  972. for tracker_state in tracker_states_local:
  973. self.tracker.propagate_in_video_preflight(
  974. tracker_state, run_mem_encoder=True
  975. )
  976. for frame_idx in tqdm(processing_order):
  977. # run Tracker propagation
  978. if propagation_type == "propagation_partial":
  979. self._prepare_backbone_feats(inference_state, frame_idx, reverse)
  980. obj_ids_local, low_res_masks_local, tracker_scores_local = (
  981. self._propogate_tracker_one_frame_local_gpu(
  982. tracker_states_local,
  983. frame_idx=frame_idx,
  984. reverse=reverse,
  985. run_mem_encoder=True,
  986. )
  987. )
  988. # broadcast refined object tracker scores and masks to all GPUs
  989. # handle multiple objects that can be located on different GPUs
  990. refined_obj_data = {} # obj_id -> (score, mask_video_res)
  991. # Collect data for objects on this GPU
  992. local_obj_data = {}
  993. for obj_id in obj_ids:
  994. obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id)
  995. if self.rank == obj_rank and obj_id in obj_ids_local:
  996. refined_obj_idx = obj_ids_local.index(obj_id)
  997. refined_mask_low_res = low_res_masks_local[
  998. refined_obj_idx
  999. ] # (H_low_res, W_low_res)
  1000. refined_score = tracker_scores_local[refined_obj_idx]
  1001. # Keep low resolution for broadcasting to reduce communication cost
  1002. local_obj_data[obj_id] = (refined_score, refined_mask_low_res)
  1003. # Broadcast data from each GPU that has refined objects
  1004. if self.world_size > 1:
  1005. for obj_id in obj_ids:
  1006. obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id)
  1007. if self.rank == obj_rank:
  1008. # This GPU has the object, broadcast its data
  1009. data_to_broadcast = local_obj_data.get(obj_id, None)
  1010. data_list = [
  1011. (data_to_broadcast[0].cpu(), data_to_broadcast[1].cpu())
  1012. ]
  1013. self.broadcast_python_obj_cpu(data_list, src=obj_rank)
  1014. if data_to_broadcast is not None:
  1015. refined_obj_data[obj_id] = data_to_broadcast
  1016. elif self.rank != obj_rank:
  1017. # This GPU doesn't have the object, receive data
  1018. data_list = [None]
  1019. self.broadcast_python_obj_cpu(data_list, src=obj_rank)
  1020. refined_obj_data[obj_id] = (
  1021. data_list[0][0].to(self.device),
  1022. data_list[0][1].to(self.device),
  1023. )
  1024. else:
  1025. # Single GPU case
  1026. refined_obj_data = local_obj_data
  1027. # Update Tracker scores for all refined objects
  1028. for obj_id, (refined_score, _) in refined_obj_data.items():
  1029. tracker_metadata["obj_id_to_tracker_score_frame_wise"][
  1030. frame_idx
  1031. ].update({obj_id: refined_score.item()})
  1032. if self.rank == 0:
  1033. # get predictions from Tracker inference states, it includes the original
  1034. # VG predictions and the refined predictions from interactivity.
  1035. # Prepare refined masks dictionary - upscale to video resolution after broadcast
  1036. refined_obj_id_to_mask = {}
  1037. for obj_id, (_, refined_mask_low_res) in refined_obj_data.items():
  1038. refined_mask_video_res = (
  1039. self._convert_low_res_mask_to_video_res(
  1040. refined_mask_low_res, inference_state
  1041. )
  1042. ) # (1, H_video, W_video) bool
  1043. refined_obj_id_to_mask[obj_id] = refined_mask_video_res
  1044. obj_id_to_mask = self._build_tracker_output(
  1045. inference_state, frame_idx, refined_obj_id_to_mask
  1046. )
  1047. out = {
  1048. "obj_id_to_mask": obj_id_to_mask,
  1049. "obj_id_to_score": tracker_metadata["obj_id_to_score"],
  1050. "obj_id_to_tracker_score": tracker_metadata[
  1051. "obj_id_to_tracker_score_frame_wise"
  1052. ][frame_idx],
  1053. }
  1054. suppressed_obj_ids = tracker_metadata["rank0_metadata"][
  1055. "suppressed_obj_ids"
  1056. ][frame_idx]
  1057. self._cache_frame_outputs(
  1058. inference_state,
  1059. frame_idx,
  1060. obj_id_to_mask,
  1061. suppressed_obj_ids=suppressed_obj_ids,
  1062. )
  1063. suppressed_obj_ids = tracker_metadata["rank0_metadata"][
  1064. "suppressed_obj_ids"
  1065. ][frame_idx]
  1066. yield (
  1067. frame_idx,
  1068. self._postprocess_output(
  1069. inference_state, out, suppressed_obj_ids=suppressed_obj_ids
  1070. ),
  1071. )
  1072. else:
  1073. yield frame_idx, None
  1074. def add_action_history(
  1075. self, inference_state, action_type, frame_idx=None, obj_ids=None
  1076. ):
  1077. """
  1078. action_history is used to automatically decide what to do during propagation.
  1079. action_type: one of ["add", "remove", "refine"] + ["propagation_full", "propagation_partial", "propagation_fetch"]
  1080. """
  1081. instance_actions = ["add", "remove", "refine"]
  1082. propagation_actions = [
  1083. "propagation_full",
  1084. "propagation_partial",
  1085. "propagation_fetch",
  1086. ]
  1087. assert action_type in instance_actions + propagation_actions, (
  1088. f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
  1089. )
  1090. action = {
  1091. "type": action_type,
  1092. "frame_idx": frame_idx,
  1093. "obj_ids": obj_ids,
  1094. }
  1095. inference_state["action_history"].append(action)
  1096. def _has_object_been_refined(self, inference_state, obj_id):
  1097. action_history = inference_state["action_history"]
  1098. for action in action_history:
  1099. if action["type"] in ["add", "refine"] and action.get("obj_ids"):
  1100. if obj_id in action["obj_ids"]:
  1101. return True
  1102. return False
  1103. def parse_action_history_for_propagation(self, inference_state):
  1104. """
  1105. Parse the actions in history before the last propagation and prepare for the next propagation.
  1106. We support multiple actions (add/remove/refine) between two propagations. If we had an action
  1107. history similar to this ["propagate", "add", "refine", "remove", "add"], the next propagation
  1108. would remove the removed object, and also propagate the two added/refined objects.
  1109. Returns:
  1110. propagation_type: one of ["propagation_full", "propagation_partial", "propagation_fetch"]
  1111. - "propagation_full": run VG propagation for all objects
  1112. - "propagation_partial": run Tracker propagation for selected objects, useful for add/refine actions
  1113. - "propagation_fetch": fetch existing VG predictions without running any propagation
  1114. obj_ids: list of object ids to run Tracker propagation on if propagation_type is "propagation_partial".
  1115. """
  1116. action_history = inference_state["action_history"]
  1117. if len(action_history) == 0:
  1118. # we run propagation for the first time
  1119. return "propagation_full", None
  1120. if "propagation" in action_history[-1]["type"]:
  1121. if action_history[-1]["type"] in ["propagation_fetch"]:
  1122. # last propagation is direct fetch, we fetch existing predictions
  1123. return "propagation_fetch", None
  1124. elif action_history[-1]["type"] in [
  1125. "propagation_partial",
  1126. "propagation_full",
  1127. ]:
  1128. # we do fetch prediction if we have already run propagation twice or we have run
  1129. # propagation once and it is from the first frame or last frame.
  1130. if (
  1131. len(action_history) > 1
  1132. and action_history[-2]["type"]
  1133. in ["propagation_partial", "propagation_full"]
  1134. ) or action_history[-1]["frame_idx"] in [
  1135. 0,
  1136. inference_state["num_frames"] - 1,
  1137. ]:
  1138. # we have run both forward and backward partial/full propagation
  1139. return "propagation_fetch", None
  1140. else:
  1141. # we have run partial/full forward or backward propagation once, need run it for the rest of the frames
  1142. return action_history[-1]["type"], action_history[-1]["obj_ids"]
  1143. # parse actions since last propagation
  1144. obj_ids = []
  1145. for action in action_history[::-1]:
  1146. if "propagation" in action["type"]:
  1147. # we reached the last propagation action, stop parsing
  1148. break
  1149. if action["type"] in ["add", "refine"]:
  1150. obj_ids.extend(action["obj_ids"])
  1151. # else action["type"] == "remove": noop
  1152. obj_ids = list(set(obj_ids)) if len(obj_ids) > 0 else None
  1153. propagation_type = (
  1154. "propagation_partial" if obj_ids is not None else "propagation_fetch"
  1155. )
  1156. return propagation_type, obj_ids
  1157. def remove_object(self, inference_state, obj_id, is_user_action=False):
  1158. """
  1159. We try to remove object from tracker states on every GPU, it will do nothing
  1160. for states without this object.
  1161. """
  1162. obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id)
  1163. assert obj_rank is not None, f"Object {obj_id} not found in any GPU."
  1164. tracker_states_local = inference_state["tracker_inference_states"]
  1165. if self.rank == obj_rank:
  1166. self._tracker_remove_object(tracker_states_local, obj_id)
  1167. if is_user_action:
  1168. self.add_action_history(
  1169. inference_state, action_type="remove", obj_ids=[obj_id]
  1170. )
  1171. # update metadata
  1172. tracker_metadata = inference_state["tracker_metadata"]
  1173. _obj_ids = tracker_metadata["obj_ids_per_gpu"][obj_rank]
  1174. tracker_metadata["obj_ids_per_gpu"][obj_rank] = _obj_ids[_obj_ids != obj_id]
  1175. tracker_metadata["num_obj_per_gpu"][obj_rank] = len(
  1176. tracker_metadata["obj_ids_per_gpu"][obj_rank]
  1177. )
  1178. tracker_metadata["obj_ids_all_gpu"] = np.concatenate(
  1179. tracker_metadata["obj_ids_per_gpu"]
  1180. )
  1181. tracker_metadata["obj_id_to_score"].pop(obj_id, None)
  1182. # tracker_metadata["max_obj_id"] # we do not reuse the object id, so we do not update it here
  1183. # Clean up cached frame outputs to remove references to the deleted object
  1184. if "cached_frame_outputs" in inference_state:
  1185. for frame_idx in inference_state["cached_frame_outputs"]:
  1186. frame_cache = inference_state["cached_frame_outputs"][frame_idx]
  1187. if obj_id in frame_cache:
  1188. del frame_cache[obj_id]
  1189. def _get_gpu_id_by_obj_id(self, inference_state, obj_id):
  1190. """
  1191. Locate GPU ID for a given object.
  1192. """
  1193. obj_ids_per_gpu = inference_state["tracker_metadata"]["obj_ids_per_gpu"]
  1194. for rank, obj_ids in enumerate(obj_ids_per_gpu):
  1195. if obj_id in obj_ids:
  1196. return rank
  1197. return None # object not found in any GPU
  1198. def _get_tracker_inference_states_by_obj_ids(self, inference_state, obj_ids):
  1199. """
  1200. Get the Tracker inference states that contain the given object ids.
  1201. This is used to run partial Tracker propagation on a single object/bucket.
  1202. Possibly multiple or zero states can be returned.
  1203. """
  1204. states = [
  1205. state
  1206. for state in inference_state["tracker_inference_states"]
  1207. if set(obj_ids) & set(state["obj_ids"])
  1208. ]
  1209. return states
  1210. def _prepare_backbone_feats(self, inference_state, frame_idx, reverse):
  1211. input_batch = inference_state["input_batch"]
  1212. feature_cache = inference_state["feature_cache"]
  1213. num_frames = inference_state["num_frames"]
  1214. geometric_prompt = (
  1215. inference_state["constants"]["empty_geometric_prompt"]
  1216. if inference_state["per_frame_geometric_prompt"][frame_idx] is None
  1217. else inference_state["per_frame_geometric_prompt"][frame_idx]
  1218. )
  1219. _ = self.run_backbone_and_detection(
  1220. frame_idx=frame_idx,
  1221. num_frames=num_frames,
  1222. input_batch=input_batch,
  1223. geometric_prompt=geometric_prompt,
  1224. feature_cache=feature_cache,
  1225. reverse=reverse,
  1226. allow_new_detections=True,
  1227. )
  1228. @torch.inference_mode()
  1229. def add_prompt(
  1230. self,
  1231. inference_state,
  1232. frame_idx,
  1233. text_str=None,
  1234. boxes_xywh=None,
  1235. box_labels=None,
  1236. points=None,
  1237. point_labels=None,
  1238. obj_id=None,
  1239. rel_coordinates=True,
  1240. ):
  1241. if points is not None:
  1242. # Tracker instance prompts
  1243. assert text_str is None and boxes_xywh is None, (
  1244. "When points are provided, text_str and boxes_xywh must be None."
  1245. )
  1246. assert obj_id is not None, (
  1247. "When points are provided, obj_id must be provided."
  1248. )
  1249. return self.add_tracker_new_points(
  1250. inference_state,
  1251. frame_idx,
  1252. obj_id=obj_id,
  1253. points=points,
  1254. labels=point_labels,
  1255. rel_coordinates=rel_coordinates,
  1256. use_prev_mem_frame=self.use_prev_mem_frame,
  1257. )
  1258. else:
  1259. # SAM3 prompts
  1260. return super().add_prompt(
  1261. inference_state,
  1262. frame_idx,
  1263. text_str=text_str,
  1264. boxes_xywh=boxes_xywh,
  1265. box_labels=box_labels,
  1266. )
  1267. @torch.inference_mode()
  1268. def add_tracker_new_points(
  1269. self,
  1270. inference_state,
  1271. frame_idx,
  1272. obj_id,
  1273. points,
  1274. labels,
  1275. rel_coordinates=True,
  1276. use_prev_mem_frame=False,
  1277. ):
  1278. """Add a new point prompt to Tracker. Suppporting instance refinement to existing
  1279. objects by passing existing obj_id or adding a new object by passing a new obj_id.
  1280. use_prev_mem_frame=False to disable cross attention to previous memory frames.
  1281. Every GPU returns the same results, and results should contain all masks including
  1282. these masks not refined or not added by the current user points.
  1283. """
  1284. assert obj_id is not None, "obj_id must be provided to add new points"
  1285. tracker_metadata = inference_state["tracker_metadata"]
  1286. if tracker_metadata == {}:
  1287. # initialize masklet metadata if it's uninitialized (empty dict)
  1288. tracker_metadata.update(self._initialize_metadata())
  1289. obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id)
  1290. # prepare feature
  1291. self._prepare_backbone_feats(inference_state, frame_idx, reverse=False)
  1292. object_has_been_refined = self._has_object_been_refined(inference_state, obj_id)
  1293. if (
  1294. obj_rank is not None
  1295. and self.use_stateless_refinement
  1296. and not object_has_been_refined
  1297. ):
  1298. # The first time we start refinement on the object, we remove it.
  1299. logger.debug(
  1300. f"[rank={self.rank}] Removing object {obj_id} before refinement."
  1301. )
  1302. self.remove_object(inference_state, obj_id, is_user_action=False)
  1303. obj_rank = None
  1304. if obj_rank is None:
  1305. # new object, we assign it a GPU and create a new inference state if limit allows
  1306. num_prev_obj = np.sum(tracker_metadata["num_obj_per_gpu"])
  1307. if num_prev_obj >= self.max_num_objects:
  1308. logger.warning(
  1309. f"add_tracker_new_points: cannot add a new object as we are already tracking {num_prev_obj=} "
  1310. f"masklets (under {self.max_num_objects=})"
  1311. )
  1312. obj_ids = []
  1313. H_low_res = W_low_res = self.tracker.low_res_mask_size
  1314. H_video_res = inference_state["orig_height"]
  1315. W_video_res = inference_state["orig_width"]
  1316. low_res_masks = torch.zeros(0, 1, H_low_res, W_low_res)
  1317. video_res_masks = torch.zeros(0, 1, H_video_res, W_video_res)
  1318. return frame_idx, obj_ids, low_res_masks, video_res_masks
  1319. new_det_gpu_ids = self._assign_new_det_to_gpus(
  1320. new_det_num=1,
  1321. prev_workload_per_gpu=tracker_metadata["num_obj_per_gpu"],
  1322. )
  1323. obj_rank = new_det_gpu_ids[0]
  1324. # get tracker inference state for the new object
  1325. if self.rank == obj_rank:
  1326. # for batched inference, we create a new inference state
  1327. tracker_state = self._init_new_tracker_state(inference_state)
  1328. inference_state["tracker_inference_states"].append(tracker_state)
  1329. # update metadata
  1330. tracker_metadata["obj_ids_per_gpu"][obj_rank] = np.concatenate(
  1331. [
  1332. tracker_metadata["obj_ids_per_gpu"][obj_rank],
  1333. np.array([obj_id], dtype=np.int64),
  1334. ]
  1335. )
  1336. tracker_metadata["num_obj_per_gpu"][obj_rank] = len(
  1337. tracker_metadata["obj_ids_per_gpu"][obj_rank]
  1338. )
  1339. tracker_metadata["obj_ids_all_gpu"] = np.concatenate(
  1340. tracker_metadata["obj_ids_per_gpu"]
  1341. )
  1342. tracker_metadata["max_obj_id"] = max(tracker_metadata["max_obj_id"], obj_id)
  1343. logger.debug(
  1344. f"[rank={self.rank}] Adding new object with id {obj_id} at frame {frame_idx}."
  1345. )
  1346. self.add_action_history(
  1347. inference_state, "add", frame_idx=frame_idx, obj_ids=[obj_id]
  1348. )
  1349. else:
  1350. # existing object, for refinement
  1351. if self.rank == obj_rank:
  1352. tracker_states = self._get_tracker_inference_states_by_obj_ids(
  1353. inference_state, [obj_id]
  1354. )
  1355. assert len(tracker_states) == 1, (
  1356. f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
  1357. )
  1358. tracker_state = tracker_states[0]
  1359. # log
  1360. logger.debug(
  1361. f"[rank={self.rank}] Refining existing object with id {obj_id} at frame {frame_idx}."
  1362. )
  1363. self.add_action_history(
  1364. inference_state, "refine", frame_idx=frame_idx, obj_ids=[obj_id]
  1365. )
  1366. # assign higher score to added/refined object
  1367. tracker_metadata["obj_id_to_score"][obj_id] = 1.0
  1368. tracker_metadata["obj_id_to_tracker_score_frame_wise"][frame_idx][obj_id] = 1.0
  1369. if self.rank == 0:
  1370. rank0_metadata = tracker_metadata.get("rank0_metadata", {})
  1371. if "removed_obj_ids" in rank0_metadata:
  1372. rank0_metadata["removed_obj_ids"].discard(obj_id)
  1373. if "suppressed_obj_ids" in rank0_metadata:
  1374. for frame_id in rank0_metadata["suppressed_obj_ids"]:
  1375. rank0_metadata["suppressed_obj_ids"][frame_id].discard(obj_id)
  1376. if "masklet_confirmation" in rank0_metadata:
  1377. obj_ids_all_gpu = tracker_metadata["obj_ids_all_gpu"]
  1378. obj_indices = np.where(obj_ids_all_gpu == obj_id)[0]
  1379. if len(obj_indices) > 0:
  1380. obj_idx = obj_indices[0]
  1381. if obj_idx < len(rank0_metadata["masklet_confirmation"]["status"]):
  1382. rank0_metadata["masklet_confirmation"]["status"][obj_idx] = 1
  1383. rank0_metadata["masklet_confirmation"]["consecutive_det_num"][
  1384. obj_idx
  1385. ] = self.masklet_confirmation_consecutive_det_thresh
  1386. if self.rank == obj_rank:
  1387. frame_idx, obj_ids, low_res_masks, video_res_masks = (
  1388. self.tracker.add_new_points(
  1389. inference_state=tracker_state,
  1390. frame_idx=frame_idx,
  1391. obj_id=obj_id,
  1392. points=points,
  1393. labels=labels,
  1394. clear_old_points=True,
  1395. rel_coordinates=rel_coordinates,
  1396. use_prev_mem_frame=use_prev_mem_frame,
  1397. )
  1398. )
  1399. if video_res_masks is not None and len(video_res_masks) > 0:
  1400. video_res_masks = fill_holes_in_mask_scores(
  1401. video_res_masks, # shape (N, 1, H_video, W_video)
  1402. max_area=self.fill_hole_area,
  1403. fill_holes=True,
  1404. remove_sprinkles=True,
  1405. )
  1406. # Since the mem encoder has already run for the current input points?
  1407. self.tracker.propagate_in_video_preflight(
  1408. tracker_state, run_mem_encoder=True
  1409. )
  1410. # Clear detector conditioning frames when user clicks are received to allow
  1411. # model updating masks on these frames. It is a noop if user is refining on the
  1412. # detector conditioning frames or adding new objects.
  1413. self.clear_detector_added_cond_frame_in_tracker(
  1414. tracker_state, obj_id, frame_idx
  1415. )
  1416. # fetch results from states and gather across GPUs
  1417. # Use optimized caching approach to avoid reprocessing unmodified objects
  1418. if self.rank == obj_rank and len(obj_ids) > 0:
  1419. new_mask_data = (video_res_masks[obj_ids.index(obj_id)] > 0.0).to(
  1420. torch.bool
  1421. )
  1422. else:
  1423. new_mask_data = None
  1424. # Broadcast the new mask data across all ranks for consistency
  1425. if self.world_size > 1:
  1426. data_list = [new_mask_data.cpu() if new_mask_data is not None else None]
  1427. self.broadcast_python_obj_cpu(data_list, src=obj_rank)
  1428. new_mask_data = data_list[0].to(self.device)
  1429. if self.rank == 0:
  1430. obj_id_to_mask = self._build_tracker_output(
  1431. inference_state,
  1432. frame_idx,
  1433. {obj_id: new_mask_data} if new_mask_data is not None else None,
  1434. )
  1435. # post processing - remove suppressed obj_ids
  1436. obj_id_to_score = tracker_metadata["obj_id_to_score"]
  1437. suppressed_obj_ids = tracker_metadata["rank0_metadata"][
  1438. "suppressed_obj_ids"
  1439. ][frame_idx]
  1440. obj_id_to_tracker_score = tracker_metadata[
  1441. "obj_id_to_tracker_score_frame_wise"
  1442. ][frame_idx]
  1443. out = {
  1444. "obj_id_to_mask": obj_id_to_mask,
  1445. "obj_id_to_score": obj_id_to_score,
  1446. "obj_id_to_tracker_score": obj_id_to_tracker_score,
  1447. }
  1448. self._cache_frame_outputs(
  1449. inference_state,
  1450. frame_idx,
  1451. obj_id_to_mask,
  1452. suppressed_obj_ids=suppressed_obj_ids,
  1453. )
  1454. return frame_idx, self._postprocess_output(
  1455. inference_state, out, suppressed_obj_ids=suppressed_obj_ids
  1456. )
  1457. else:
  1458. return frame_idx, None # no output on other GPUs
  def _gather_obj_id_to_mask_across_gpus(self, inference_state, obj_id_to_mask_local):
      """Stack this rank's low-res masks and all-gather them across ranks.

      Masks are ordered by `tracker_metadata["obj_ids_per_gpu"][self.rank]`. An
      object with no entry in `obj_id_to_mask_local` is represented by a mask
      filled with -1024.0. No resizing is done here; the output stays at
      `self.tracker.low_res_mask_size`.

      Args:
          inference_state: Mapping whose "tracker_metadata" entry provides
              "obj_ids_per_gpu" and "num_obj_per_gpu".
          obj_id_to_mask_local: Mapping from object id to that object's
              (H_mask, W_mask) mask on this rank.

      Returns:
          Tensor of shape (num_obj, H_mask, W_mask). With `world_size > 1` it is
          the float concatenation of every rank's masks in rank order;
          otherwise it is just the local stack, with no dtype cast.
      """
      metadata = inference_state["tracker_metadata"]
      side = self.tracker.low_res_mask_size
      mask_hw = (side, side)

      # one mask per locally owned object, with a constant placeholder where missing
      per_obj_masks = [
          obj_id_to_mask_local[oid]
          if oid in obj_id_to_mask_local
          else torch.full(mask_hw, -1024.0, device=self.device)
          for oid in metadata["obj_ids_per_gpu"][self.rank]
      ]
      if per_obj_masks:
          local_stack = torch.stack(per_obj_masks, dim=0)  # (N, H, W)
          assert local_stack.shape[1:] == mask_hw
      else:
          local_stack = torch.zeros(0, side, side, device=self.device)

      if self.world_size <= 1:
          return local_stack

      # receive buffers are sized from the per-rank object counts in the metadata
      local_stack = local_stack.float().contiguous()
      peer_buffers = [
          local_stack.new_empty(count, side, side)
          for count in metadata["num_obj_per_gpu"]
      ]
      dist.all_gather(peer_buffers, local_stack)
      return torch.cat(peer_buffers, dim=0)
  def _convert_low_res_mask_to_video_res(self, low_res_mask, inference_state):
      """
      Upsample one low-res mask to the video's original size and binarize it.

      Args:
          low_res_mask: Tensor of shape (H_low_res, W_low_res), or None
          inference_state: Mapping providing "orig_height" and "orig_width"

      Returns:
          video_res_mask: Bool tensor of shape (1, H_video, W_video), True where
          the bilinearly interpolated value is > 0; None if `low_res_mask` is None
      """
      if low_res_mask is None:
          return None

      # F.interpolate expects (N, C, H, W): add a batch axis and a channel axis
      batched = low_res_mask[None, None]
      target_size = (inference_state["orig_height"], inference_state["orig_width"])
      upsampled = F.interpolate(
          batched.float(),
          size=target_size,
          mode="bilinear",
          align_corners=False,
      )  # (1, 1, H_video, W_video)

      # drop the batch axis -> (1, H_video, W_video); the comparison already yields bool
      return upsampled[0] > 0.0
  def clear_detector_added_cond_frame_in_tracker(
      self, tracker_state, obj_id, refined_frame_idx
  ):
      """Drop mask-only conditioning frames that lie close to a refined frame.

      A frame qualifies when it has a mask input for `obj_id` but no point input,
      and it is at most `self.refinement_detector_cond_frame_removal_window`
      frames away from `refined_frame_idx`. Each qualifying frame is cleared for
      every object id registered in `tracker_state["obj_id_to_idx"]`, so that the
      masks on those frames can be updated again.

      Args:
          tracker_state: Tracker inference state holding "mask_inputs_per_obj",
              "point_inputs_per_obj" and "obj_id_to_idx".
          obj_id: Id of the object that was just refined.
          refined_frame_idx: Frame index on which the refinement happened.
      """
      obj_idx = self.tracker._obj_id_to_idx(tracker_state, obj_id)
      window = self.refinement_detector_cond_frame_removal_window

      # Collect first, clear afterwards: the frames are gathered while iterating
      # the mask inputs, and only then handed to the tracker for clearing.
      mask_only_cond_frame_indices = [
          cond_frame
          for cond_frame in tracker_state["mask_inputs_per_obj"][obj_idx]
          if cond_frame not in tracker_state["point_inputs_per_obj"][obj_idx]
          and abs(cond_frame - refined_frame_idx) <= window
      ]
      if not mask_only_cond_frame_indices:
          return

      for cond_frame in mask_only_cond_frame_indices:
          # every object in the state is cleared on this frame, not only `obj_id`
          # (the original comment notes they are bucket batched together)
          for other_obj_id in tracker_state["obj_id_to_idx"]:
              self.tracker.clear_all_points_in_frame(
                  tracker_state, cond_frame, other_obj_id, need_output=False
              )
      logger.debug(
          f"Cleared detector mask only conditioning frames ({mask_only_cond_frame_indices}) in Tracker."
      )
  def is_image_type(resource_path: str) -> bool:
      """Report whether `resource_path` counts as an image resource.

      A list counts as an image only when it holds exactly one entry. Any other
      value is lower-cased and checked against the suffixes in `IMAGE_EXTS`.
      """
      if not isinstance(resource_path, list):
          lowered = resource_path.lower()
          return lowered.endswith(tuple(IMAGE_EXTS))
      return len(resource_path) == 1