model_builder.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import os
from typing import Optional

import pkg_resources
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from iopath.common.file_io import g_pathmgr

from sam3.model.decoder import (
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerDecoderLayerv2,
    TransformerEncoderCrossAttention,
)
from sam3.model.encoder import TransformerEncoderFusion, TransformerEncoderLayer
from sam3.model.geometry_encoders import SequenceGeometryEncoder
from sam3.model.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
from sam3.model.memory import (
    CXBlock,
    SimpleFuser,
    SimpleMaskDownSampler,
    SimpleMaskEncoder,
)
from sam3.model.model_misc import (
    DotProductScoring,
    MLP,
    MultiheadAttentionWrapper as MultiheadAttention,
    TransformerWrapper,
)
from sam3.model.necks import Sam3DualViTDetNeck
from sam3.model.position_encoding import PositionEmbeddingSine
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
from sam3.model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU
from sam3.model.sam3_tracking_predictor import Sam3TrackerPredictor
from sam3.model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity
from sam3.model.sam3_video_predictor import Sam3VideoPredictorMultiGPU
from sam3.model.text_encoder_ve import VETextEncoder
from sam3.model.tokenizer_ve import SimpleTokenizer
from sam3.model.vitdet import ViT
from sam3.model.vl_combiner import SAM3VLBackbone
from sam3.sam.transformer import RoPEAttention


# Set up TensorFloat-32 for Ampere GPUs if available
def _setup_tf32() -> None:
    """Enable TensorFloat-32 for Ampere GPUs if available."""
    if torch.cuda.is_available():
        device_props = torch.cuda.get_device_properties(0)
        if device_props.major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True


_setup_tf32()


def _create_position_encoding(precompute_resolution=None):
    """Create position encoding for the visual backbone."""
    return PositionEmbeddingSine(
        num_pos_feats=256,
        normalize=True,
        scale=None,
        temperature=10000,
        precompute_resolution=precompute_resolution,
    )


def _create_vit_backbone(compile_mode=None):
    """Create ViT backbone for visual feature extraction."""
    return ViT(
        img_size=1008,
        pretrain_img_size=336,
        patch_size=14,
        embed_dim=1024,
        depth=32,
        num_heads=16,
        mlp_ratio=4.625,
        norm_layer="LayerNorm",
        drop_path_rate=0.1,
        qkv_bias=True,
        use_abs_pos=True,
        tile_abs_pos=True,
        global_att_blocks=(7, 15, 23, 31),
        rel_pos_blocks=(),
        use_rope=True,
        use_interp_rope=True,
        window_size=24,
        pretrain_use_cls_token=True,
        retain_cls_token=False,
        ln_pre=True,
        ln_post=False,
        return_interm_layers=False,
        bias_patch_embed=False,
        compile_mode=compile_mode,
    )


def _create_vit_neck(position_encoding, vit_backbone, enable_inst_interactivity=False):
    """Create ViT neck for the feature pyramid."""
    return Sam3DualViTDetNeck(
        position_encoding=position_encoding,
        d_model=256,
        scale_factors=[4.0, 2.0, 1.0, 0.5],
        trunk=vit_backbone,
        add_sam2_neck=enable_inst_interactivity,
    )


def _create_vl_backbone(vit_neck, text_encoder):
    """Create visual-language backbone."""
    return SAM3VLBackbone(visual=vit_neck, text=text_encoder, scalp=1)


def _create_transformer_encoder() -> TransformerEncoderFusion:
    """Create transformer encoder with its layer."""
    encoder_layer = TransformerEncoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=True,
        pos_enc_at_cross_attn_keys=False,
        pos_enc_at_cross_attn_queries=False,
        pre_norm=True,
        self_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
            batch_first=True,
        ),
        cross_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
            batch_first=True,
        ),
    )
    encoder = TransformerEncoderFusion(
        layer=encoder_layer,
        num_layers=6,
        d_model=256,
        num_feature_levels=1,
        frozen=False,
        use_act_checkpoint=True,
        add_pooled_text_to_img_feat=False,
        pool_text_with_mask=True,
    )
    return encoder


def _create_transformer_decoder() -> TransformerDecoder:
    """Create transformer decoder with its layer."""
    decoder_layer = TransformerDecoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        cross_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
        ),
        n_heads=8,
        use_text_cross_attention=True,
    )
    decoder = TransformerDecoder(
        layer=decoder_layer,
        num_layers=6,
        num_queries=200,
        return_intermediate=True,
        box_refine=True,
        num_o2m_queries=0,
        dac=True,
        boxRPB="log",
        d_model=256,
        frozen=False,
        interaction_layer=None,
        dac_use_selfatt_ln=True,
        resolution=1008,
        stride=14,
        use_act_checkpoint=True,
        presence_token=True,
    )
    return decoder


def _create_dot_product_scoring():
    """Create dot product scoring module."""
    prompt_mlp = MLP(
        input_dim=256,
        hidden_dim=2048,
        output_dim=256,
        num_layers=2,
        dropout=0.1,
        residual=True,
        out_norm=nn.LayerNorm(256),
    )
    return DotProductScoring(d_model=256, d_proj=256, prompt_mlp=prompt_mlp)


def _create_segmentation_head(compile_mode=None):
    """Create segmentation head with pixel decoder."""
    pixel_decoder = PixelDecoder(
        num_upsampling_stages=3,
        interpolation_mode="nearest",
        hidden_dim=256,
        compile_mode=compile_mode,
    )
    cross_attend_prompt = MultiheadAttention(
        num_heads=8,
        dropout=0,
        embed_dim=256,
    )
    segmentation_head = UniversalSegmentationHead(
        hidden_dim=256,
        upsampling_stages=3,
        aux_masks=False,
        presence_head=False,
        dot_product_scorer=None,
        act_ckpt=True,
        cross_attend_prompt=cross_attend_prompt,
        pixel_decoder=pixel_decoder,
    )
    return segmentation_head


def _create_geometry_encoder():
    """Create geometry encoder with all its components."""
    # Create position encoding for geometry encoder
    geo_pos_enc = _create_position_encoding()
    # Create CX block for fuser
    cx_block = CXBlock(
        dim=256,
        kernel_size=7,
        padding=3,
        layer_scale_init_value=1.0e-06,
        use_dwconv=True,
    )
    # Create geometry encoder layer
    geo_layer = TransformerEncoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=False,
        pre_norm=True,
        self_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
            batch_first=False,
        ),
        pos_enc_at_cross_attn_queries=False,
        pos_enc_at_cross_attn_keys=True,
        cross_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
            batch_first=False,
        ),
    )
    # Create geometry encoder
    input_geometry_encoder = SequenceGeometryEncoder(
        pos_enc=geo_pos_enc,
        encode_boxes_as_points=False,
        points_direct_project=True,
        points_pool=True,
        points_pos_enc=True,
        boxes_direct_project=True,
        boxes_pool=True,
        boxes_pos_enc=True,
        d_model=256,
        num_layers=3,
        layer=geo_layer,
        use_act_ckpt=True,
        add_cls=True,
        add_post_encode_proj=True,
    )
    return input_geometry_encoder


def _create_sam3_model(
    backbone,
    transformer,
    input_geometry_encoder,
    segmentation_head,
    dot_prod_scoring,
    inst_interactive_predictor,
    eval_mode,
):
    """Create the SAM3 image model."""
    common_params = {
        "backbone": backbone,
        "transformer": transformer,
        "input_geometry_encoder": input_geometry_encoder,
        "segmentation_head": segmentation_head,
        "num_feature_levels": 1,
        "o2m_mask_predict": True,
        "dot_prod_scoring": dot_prod_scoring,
        "use_instance_query": False,
        "multimask_output": True,
        "inst_interactive_predictor": inst_interactive_predictor,
    }
    matcher = None
    if not eval_mode:
        from sam3.train.matcher import BinaryHungarianMatcherV2

        matcher = BinaryHungarianMatcherV2(
            focal=True,
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
            alpha=0.25,
            gamma=2,
            stable=False,
        )
    common_params["matcher"] = matcher
    model = Sam3Image(**common_params)
    return model


def _create_tracker_maskmem_backbone():
    """Create the SAM3 Tracker memory encoder."""
    # Position encoding for mask memory backbone
    position_encoding = PositionEmbeddingSine(
        num_pos_feats=64,
        normalize=True,
        scale=None,
        temperature=10000,
        precompute_resolution=1008,
    )
    # Mask processing components
    mask_downsampler = SimpleMaskDownSampler(
        kernel_size=3, stride=2, padding=1, interpol_size=[1152, 1152]
    )
    cx_block_layer = CXBlock(
        dim=256,
        kernel_size=7,
        padding=3,
        layer_scale_init_value=1.0e-06,
        use_dwconv=True,
    )
    fuser = SimpleFuser(layer=cx_block_layer, num_layers=2)
    maskmem_backbone = SimpleMaskEncoder(
        out_dim=64,
        position_encoding=position_encoding,
        mask_downsampler=mask_downsampler,
        fuser=fuser,
    )
    return maskmem_backbone


def _create_tracker_transformer():
    """Create the SAM3 Tracker transformer components."""
    # Self attention
    self_attention = RoPEAttention(
        embedding_dim=256,
        num_heads=1,
        downsample_rate=1,
        dropout=0.1,
        rope_theta=10000.0,
        feat_sizes=[72, 72],
        use_fa3=False,
        use_rope_real=False,
    )
    # Cross attention
    cross_attention = RoPEAttention(
        embedding_dim=256,
        num_heads=1,
        downsample_rate=1,
        dropout=0.1,
        kv_in_dim=64,
        rope_theta=10000.0,
        feat_sizes=[72, 72],
        rope_k_repeat=True,
        use_fa3=False,
        use_rope_real=False,
    )
    # Encoder layer
    encoder_layer = TransformerDecoderLayerv2(
        cross_attention_first=False,
        activation="relu",
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=False,
        pre_norm=True,
        self_attention=self_attention,
        d_model=256,
        pos_enc_at_cross_attn_keys=True,
        pos_enc_at_cross_attn_queries=False,
        cross_attention=cross_attention,
    )
    # Encoder
    encoder = TransformerEncoderCrossAttention(
        remove_cross_attention_layers=[],
        batch_first=True,
        d_model=256,
        frozen=False,
        pos_enc_at_input=True,
        layer=encoder_layer,
        num_layers=4,
        use_act_checkpoint=False,
    )
    # Transformer wrapper
    transformer = TransformerWrapper(
        encoder=encoder,
        decoder=None,
        d_model=256,
    )
    return transformer


def build_tracker(
    apply_temporal_disambiguation: bool, with_backbone: bool = False, compile_mode=None
) -> Sam3TrackerPredictor:
    """
    Build the SAM3 Tracker module for video tracking.

    Args:
        apply_temporal_disambiguation: Whether to enable memory selection
            (temporal disambiguation) in the tracker
        with_backbone: Whether to attach a vision backbone to the tracker
        compile_mode: Optional compile mode forwarded to the ViT backbone

    Returns:
        Sam3TrackerPredictor: Wrapped SAM3 Tracker module
    """
    # Create model components
    maskmem_backbone = _create_tracker_maskmem_backbone()
    transformer = _create_tracker_transformer()
    backbone = None
    if with_backbone:
        vision_backbone = _create_vision_backbone(compile_mode=compile_mode)
        backbone = SAM3VLBackbone(scalp=1, visual=vision_backbone, text=None)
    # Create the Tracker module
    model = Sam3TrackerPredictor(
        image_size=1008,
        num_maskmem=7,
        backbone=backbone,
        backbone_stride=14,
        transformer=transformer,
        maskmem_backbone=maskmem_backbone,
        # SAM parameters
        multimask_output_in_sam=True,
        # Evaluation
        forward_backbone_per_frame_for_eval=True,
        trim_past_non_cond_mem_for_eval=False,
        # Multimask
        multimask_output_for_tracking=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        # Additional settings
        always_start_from_first_ann_frame=False,
        # Mask overlap
        non_overlap_masks_for_mem_enc=False,
        non_overlap_masks_for_output=False,
        max_cond_frames_in_attn=4,
        offload_output_to_cpu_for_eval=False,
        # SAM decoder settings
        sam_mask_decoder_extra_args={
            "dynamic_multimask_via_stability": True,
            "dynamic_multimask_stability_delta": 0.05,
            "dynamic_multimask_stability_thresh": 0.98,
        },
        clear_non_cond_mem_around_input=True,
        fill_hole_area=0,
        use_memory_selection=apply_temporal_disambiguation,
    )
    return model
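

# Illustrative usage sketch (hypothetical helper, never called in this module):
# the two ways the public builders below use build_tracker: without temporal
# disambiguation for the SAM1-style interactive image path, and with it for
# dense video tracking.
def _example_build_trackers():
    image_path_tracker = build_tracker(apply_temporal_disambiguation=False)
    video_tracker = build_tracker(apply_temporal_disambiguation=True)
    return image_path_tracker, video_tracker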


def _create_text_encoder(bpe_path: str) -> VETextEncoder:
    """Create SAM3 text encoder."""
    tokenizer = SimpleTokenizer(bpe_path=bpe_path)
    return VETextEncoder(
        tokenizer=tokenizer,
        d_model=256,
        width=1024,
        heads=16,
        layers=24,
    )


def _create_vision_backbone(
    compile_mode=None, enable_inst_interactivity=True
) -> Sam3DualViTDetNeck:
    """Create SAM3 visual backbone with ViT and neck."""
    # Position encoding
    position_encoding = _create_position_encoding(precompute_resolution=1008)
    # ViT backbone
    vit_backbone: ViT = _create_vit_backbone(compile_mode=compile_mode)
    # Visual neck
    vit_neck: Sam3DualViTDetNeck = _create_vit_neck(
        position_encoding,
        vit_backbone,
        enable_inst_interactivity=enable_inst_interactivity,
    )
    return vit_neck


def _create_sam3_transformer(has_presence_token: bool = True) -> TransformerWrapper:
    """Create SAM3 transformer encoder and decoder."""
    encoder: TransformerEncoderFusion = _create_transformer_encoder()
    decoder: TransformerDecoder = _create_transformer_decoder()
    return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)


def _load_checkpoint(model, checkpoint_path):
    """Load a model checkpoint from file."""
    with g_pathmgr.open(checkpoint_path, "rb") as f:
        ckpt = torch.load(f, map_location="cpu", weights_only=True)
    if "model" in ckpt and isinstance(ckpt["model"], dict):
        ckpt = ckpt["model"]
    sam3_image_ckpt = {
        k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k
    }
    if model.inst_interactive_predictor is not None:
        sam3_image_ckpt.update(
            {
                k.replace("tracker.", "inst_interactive_predictor.model."): v
                for k, v in ckpt.items()
                if "tracker" in k
            }
        )
    missing_keys, unexpected_keys = model.load_state_dict(sam3_image_ckpt, strict=False)
    if len(missing_keys) > 0 or len(unexpected_keys) > 0:
        print(
            f"loaded {checkpoint_path} and found "
            f"missing and/or unexpected keys:\n{missing_keys=}\n{unexpected_keys=}"
        )


def _setup_device_and_mode(model, device, eval_mode):
    """Set up model device and evaluation mode."""
    if device == "cuda":
        model = model.cuda()
    if eval_mode:
        model.eval()
    return model


def build_sam3_image_model(
    bpe_path=None,
    device="cuda" if torch.cuda.is_available() else "cpu",
    eval_mode=True,
    checkpoint_path=None,
    load_from_HF=True,
    enable_segmentation=True,
    enable_inst_interactivity=False,
    compile=False,
):
    """
    Build the SAM3 image model.

    Args:
        bpe_path: Path to the BPE tokenizer vocabulary; defaults to the copy
            bundled with the sam3 package
        device: Device to load the model on ('cuda' or 'cpu')
        eval_mode: Whether to set the model to evaluation mode
        checkpoint_path: Optional path to a model checkpoint
        load_from_HF: If True and no checkpoint_path is given, download the
            checkpoint from Hugging Face
        enable_segmentation: Whether to enable the segmentation head
        enable_inst_interactivity: Whether to enable instance interactivity (SAM 1 task)
        compile: If True, enable compilation (compile_mode="default" for the
            ViT backbone and pixel decoder)

    Returns:
        A SAM3 image model
    """
    if bpe_path is None:
        bpe_path = pkg_resources.resource_filename(
            "sam3", "assets/bpe_simple_vocab_16e6.txt.gz"
        )
    # Create visual components
    compile_mode = "default" if compile else None
    vision_encoder = _create_vision_backbone(
        compile_mode=compile_mode, enable_inst_interactivity=enable_inst_interactivity
    )
    # Create text components
    text_encoder = _create_text_encoder(bpe_path)
    # Create visual-language backbone
    backbone = _create_vl_backbone(vision_encoder, text_encoder)
    # Create transformer components
    transformer = _create_sam3_transformer()
    # Create dot product scoring
    dot_prod_scoring = _create_dot_product_scoring()
    # Create segmentation head if enabled
    segmentation_head = (
        _create_segmentation_head(compile_mode=compile_mode)
        if enable_segmentation
        else None
    )
    # Create geometry encoder
    input_geometry_encoder = _create_geometry_encoder()
    if enable_inst_interactivity:
        sam3_pvs_base = build_tracker(apply_temporal_disambiguation=False)
        inst_predictor = SAM3InteractiveImagePredictor(sam3_pvs_base)
    else:
        inst_predictor = None
    # Create the SAM3 model
    model = _create_sam3_model(
        backbone,
        transformer,
        input_geometry_encoder,
        segmentation_head,
        dot_prod_scoring,
        inst_predictor,
        eval_mode,
    )
    if load_from_HF and checkpoint_path is None:
        checkpoint_path = download_ckpt_from_hf()
    # Load checkpoint if provided
    if checkpoint_path is not None:
        _load_checkpoint(model, checkpoint_path)
    # Setup device and mode
    model = _setup_device_and_mode(model, device, eval_mode)
    return model
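

# Illustrative usage sketch (hypothetical helper, never called in this module):
# build the image model on CPU without downloading weights from Hugging Face,
# which yields a randomly initialized model that is still useful for quick
# interface and shape checks.
def _example_build_image_model_cpu():
    model = build_sam3_image_model(
        device="cpu",
        eval_mode=True,
        checkpoint_path=None,
        load_from_HF=False,  # skip the Hugging Face download
        enable_segmentation=True,
        enable_inst_interactivity=False,
        compile=False,
    )
    return model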


def download_ckpt_from_hf():
    """Download the SAM3 checkpoint (and config) from Hugging Face and return its local path."""
    SAM3_MODEL_ID = "facebook/sam3"
    SAM3_CKPT_NAME = "sam3.pt"
    SAM3_CFG_NAME = "config.json"
    _ = hf_hub_download(repo_id=SAM3_MODEL_ID, filename=SAM3_CFG_NAME)
    checkpoint_path = hf_hub_download(repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME)
    return checkpoint_path
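

# Illustrative usage sketch (hypothetical helper, never called in this module):
# pre-fetch the SAM3 checkpoint into the local Hugging Face cache, e.g. on a
# machine with network access, so later builds can pass the returned path as
# checkpoint_path.
def _example_prefetch_checkpoint():
    ckpt_path = download_ckpt_from_hf()
    print(f"SAM3 checkpoint cached at: {ckpt_path}")
    return ckpt_path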


def build_sam3_video_model(
    checkpoint_path: Optional[str] = None,
    load_from_HF=True,
    bpe_path: Optional[str] = None,
    has_presence_token: bool = True,
    geo_encoder_use_img_cross_attn: bool = True,
    strict_state_dict_loading: bool = True,
    apply_temporal_disambiguation: bool = True,
    device="cuda" if torch.cuda.is_available() else "cpu",
    compile=False,
) -> Sam3VideoInferenceWithInstanceInteractivity:
    """
    Build the SAM3 dense tracking (video) model.

    Args:
        checkpoint_path: Optional path to a checkpoint file
        load_from_HF: If True and no checkpoint_path is given, download the
            checkpoint from Hugging Face
        bpe_path: Path to the BPE tokenizer file; defaults to the copy bundled
            with the sam3 package
        has_presence_token: Whether the detector supervises joint box scores
            with a presence token
        strict_state_dict_loading: Whether to load the checkpoint with strict
            state-dict matching
        apply_temporal_disambiguation: Whether to enable the temporal
            disambiguation heuristics (hotstart, reconditioning, etc.)
        device: Device to load the model on ('cuda' or 'cpu')
        compile: Whether to compile the model

    Returns:
        Sam3VideoInferenceWithInstanceInteractivity: The instantiated dense tracking model
    """
    if bpe_path is None:
        bpe_path = pkg_resources.resource_filename(
            "sam3", "assets/bpe_simple_vocab_16e6.txt.gz"
        )
    # Build Tracker module
    tracker = build_tracker(apply_temporal_disambiguation=apply_temporal_disambiguation)
    # Build Detector components
    visual_neck = _create_vision_backbone()
    text_encoder = _create_text_encoder(bpe_path)
    backbone = SAM3VLBackbone(scalp=1, visual=visual_neck, text=text_encoder)
    transformer = _create_sam3_transformer(has_presence_token=has_presence_token)
    segmentation_head: UniversalSegmentationHead = _create_segmentation_head()
    input_geometry_encoder = _create_geometry_encoder()
    # Create main dot product scoring
    main_dot_prod_mlp = MLP(
        input_dim=256,
        hidden_dim=2048,
        output_dim=256,
        num_layers=2,
        dropout=0.1,
        residual=True,
        out_norm=nn.LayerNorm(256),
    )
    main_dot_prod_scoring = DotProductScoring(
        d_model=256, d_proj=256, prompt_mlp=main_dot_prod_mlp
    )
    # Build Detector module
    detector = Sam3ImageOnVideoMultiGPU(
        num_feature_levels=1,
        backbone=backbone,
        transformer=transformer,
        segmentation_head=segmentation_head,
        semantic_segmentation_head=None,
        input_geometry_encoder=input_geometry_encoder,
        use_early_fusion=True,
        use_dot_prod_scoring=True,
        dot_prod_scoring=main_dot_prod_scoring,
        supervise_joint_box_scores=has_presence_token,
    )
    # Build the main SAM3 video model
    if apply_temporal_disambiguation:
        model = Sam3VideoInferenceWithInstanceInteractivity(
            detector=detector,
            tracker=tracker,
            score_threshold_detection=0.5,
            assoc_iou_thresh=0.1,
            det_nms_thresh=0.1,
            new_det_thresh=0.7,
            hotstart_delay=15,
            hotstart_unmatch_thresh=8,
            hotstart_dup_thresh=8,
            suppress_unmatched_only_within_hotstart=True,
            min_trk_keep_alive=-1,
            max_trk_keep_alive=30,
            init_trk_keep_alive=30,
            suppress_overlapping_based_on_recent_occlusion_threshold=0.7,
            suppress_det_close_to_boundary=False,
            fill_hole_area=16,
            recondition_every_nth_frame=16,
            masklet_confirmation_enable=False,
            decrease_trk_keep_alive_for_empty_masklets=False,
            image_size=1008,
            image_mean=(0.5, 0.5, 0.5),
            image_std=(0.5, 0.5, 0.5),
            compile_model=compile,
        )
    else:
        # A version without any heuristics, for ablation studies
        model = Sam3VideoInferenceWithInstanceInteractivity(
            detector=detector,
            tracker=tracker,
            score_threshold_detection=0.5,
            assoc_iou_thresh=0.1,
            det_nms_thresh=0.1,
            new_det_thresh=0.7,
            hotstart_delay=0,
            hotstart_unmatch_thresh=0,
            hotstart_dup_thresh=0,
            suppress_unmatched_only_within_hotstart=True,
            min_trk_keep_alive=-1,
            max_trk_keep_alive=30,
            init_trk_keep_alive=30,
            suppress_overlapping_based_on_recent_occlusion_threshold=0.7,
            suppress_det_close_to_boundary=False,
            fill_hole_area=16,
            recondition_every_nth_frame=0,
            masklet_confirmation_enable=False,
            decrease_trk_keep_alive_for_empty_masklets=False,
            image_size=1008,
            image_mean=(0.5, 0.5, 0.5),
            image_std=(0.5, 0.5, 0.5),
            compile_model=compile,
        )
    # Load checkpoint if provided
    if load_from_HF and checkpoint_path is None:
        checkpoint_path = download_ckpt_from_hf()
    if checkpoint_path is not None:
        with g_pathmgr.open(checkpoint_path, "rb") as f:
            ckpt = torch.load(f, map_location="cpu", weights_only=True)
        if "model" in ckpt and isinstance(ckpt["model"], dict):
            ckpt = ckpt["model"]
        missing_keys, unexpected_keys = model.load_state_dict(
            ckpt, strict=strict_state_dict_loading
        )
        if missing_keys:
            print(f"Missing keys: {missing_keys}")
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys}")
    model.to(device=device)
    return model
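

# Illustrative usage sketch (hypothetical helper, never called in this module):
# build the dense tracking model with pretrained weights from the
# "facebook/sam3" Hugging Face repo, keeping the default temporal
# disambiguation heuristics.
def _example_build_video_model():
    model = build_sam3_video_model(
        checkpoint_path=None,  # resolved via download_ckpt_from_hf()
        load_from_HF=True,
        apply_temporal_disambiguation=True,
        device="cuda" if torch.cuda.is_available() else "cpu",
        compile=False,
    )
    return model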


def build_sam3_video_predictor(*model_args, gpus_to_use=None, **model_kwargs):
    """Build a multi-GPU SAM3 video predictor, forwarding all arguments to it."""
    return Sam3VideoPredictorMultiGPU(
        *model_args, gpus_to_use=gpus_to_use, **model_kwargs
    )