| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- import logging
- import torch
- from hydra import compose
- from hydra.utils import instantiate
- from omegaconf import OmegaConf
- from huggingface_hub import hf_hub_download
def build_sam2(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=None,
    apply_postprocessing=True,
):
    """Build a SAM 2 model from a Hydra config and an optional checkpoint.

    Args:
        config_file: Config name handed to Hydra's ``compose``.
        ckpt_path: Checkpoint path; when ``None`` no weights are loaded.
        device: Device the instantiated model is moved to.
        mode: When equal to ``"eval"`` the model is switched to eval mode;
            any other value leaves the model's mode untouched.
        hydra_overrides_extra: Optional iterable of Hydra override strings.
            It is copied, never mutated.
        apply_postprocessing: When true, overrides are appended that let the
            mask decoder fall back to multi-mask output if the single mask
            is not stable.

    Returns:
        The instantiated model, on ``device``.

    Raises:
        RuntimeError: From ``_load_checkpoint`` when the checkpoint keys do
            not match the model.
    """
    # Work on a private copy so neither the caller's list nor a shared
    # default object can ever be modified by this function.
    overrides = [] if hydra_overrides_extra is None else list(hydra_overrides_extra)

    if apply_postprocessing:
        overrides += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)

    # Weights are loaded before the device move, so loading happens on
    # whatever device the model was instantiated on.
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model
def build_sam2_video_predictor(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=None,
    apply_postprocessing=True,
):
    """Build a SAM 2 video predictor from a Hydra config and optional checkpoint.

    The config's ``model._target_`` is overridden so that the instantiated
    object is ``sam2.sam2_video_predictor.SAM2VideoPredictor``.

    Args:
        config_file: Config name handed to Hydra's ``compose``.
        ckpt_path: Checkpoint path; when ``None`` no weights are loaded.
        device: Device the instantiated model is moved to.
        mode: When equal to ``"eval"`` the model is switched to eval mode;
            any other value leaves the model's mode untouched.
        hydra_overrides_extra: Optional iterable of Hydra override strings,
            applied after the ``_target_`` override. It is copied, never
            mutated.
        apply_postprocessing: When true, the post-processing overrides listed
            in the body are appended after the caller's extra overrides.

    Returns:
        The instantiated predictor, on ``device``.

    Raises:
        RuntimeError: From ``_load_checkpoint`` when the checkpoint keys do
            not match the model.
    """
    # Final override order: target class, caller extras, post-processing.
    hydra_overrides = [
        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
    ]
    if hydra_overrides_extra is not None:
        # extend() copies the items, so the caller's sequence is left untouched.
        hydra_overrides.extend(hydra_overrides_extra)

    if apply_postprocessing:
        hydra_overrides += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # binarize the sigmoid mask logits on interacted frames with clicks
            # in the memory encoder so that the encoded masks are exactly as
            # what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area`
            # (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
        ]

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)

    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model
# Hugging Face Hub repo id -> (config filename, checkpoint filename).
_HF_MODEL_ID_TO_FILENAMES = {
    "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
    "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
    "facebook/sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"),
    "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
}


def _hf_download_config_and_checkpoint(model_id):
    """Download the config and checkpoint for ``model_id`` from the HF Hub.

    Returns:
        ``(config_file, ckpt_path)`` as returned by ``hf_hub_download``.

    Raises:
        KeyError: If ``model_id`` is not a known SAM 2 repo id.
    """
    # Validate before touching the network so a typo fails fast and clearly.
    try:
        config_name, checkpoint_name = _HF_MODEL_ID_TO_FILENAMES[model_id]
    except KeyError:
        known = ", ".join(sorted(_HF_MODEL_ID_TO_FILENAMES))
        raise KeyError(
            f"Unknown SAM 2 model id {model_id!r}; expected one of: {known}"
        ) from None

    # NOTE(review): the downloaded config *path* is later passed to Hydra's
    # `compose(config_name=...)`; confirm Hydra's search path resolves it, or
    # pass `config_name` directly if the configs ship with the package.
    config_file = hf_hub_download(repo_id=model_id, filename=config_name)
    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
    return config_file, ckpt_path


def build_sam2_hf(model_id, **kwargs):
    """Build a SAM 2 model from files hosted on the Hugging Face Hub.

    Args:
        model_id: Hub repo id; must be a key of ``_HF_MODEL_ID_TO_FILENAMES``.
        **kwargs: Forwarded to ``build_sam2``.

    Returns:
        Whatever ``build_sam2`` returns.

    Raises:
        KeyError: If ``model_id`` is not a known SAM 2 repo id.
    """
    config_file, ckpt_path = _hf_download_config_and_checkpoint(model_id)
    # Must delegate to the image-model builder, not the video predictor one.
    return build_sam2(config_file=config_file, ckpt_path=ckpt_path, **kwargs)


def build_sam2_video_predictor_hf(model_id, **kwargs):
    """Build a SAM 2 video predictor from files hosted on the Hugging Face Hub.

    Args:
        model_id: Hub repo id; must be a key of ``_HF_MODEL_ID_TO_FILENAMES``.
        **kwargs: Forwarded to ``build_sam2_video_predictor``.

    Returns:
        Whatever ``build_sam2_video_predictor`` returns.

    Raises:
        KeyError: If ``model_id`` is not a known SAM 2 repo id.
    """
    # Use the same filename table as `build_sam2_hf`; the repo id itself
    # (e.g. "facebook/sam2-hiera-tiny") is not a filename inside the repo.
    config_file, ckpt_path = _hf_download_config_and_checkpoint(model_id)
    return build_sam2_video_predictor(
        config_file=config_file, ckpt_path=ckpt_path, **kwargs
    )
def _load_checkpoint(model, ckpt_path):
    """Load the ``"model"`` state dict stored at ``ckpt_path`` into ``model``.

    Args:
        model: Object exposing ``load_state_dict``.
        ckpt_path: Path to a checkpoint whose top-level object has a
            ``"model"`` entry; when ``None`` the call is a no-op.

    Raises:
        RuntimeError: If ``load_state_dict`` reports missing or unexpected
            keys. With PyTorch's default strict loading, ``load_state_dict``
            itself already raises ``RuntimeError`` on such mismatches, so the
            explicit checks below act as a second line of defence for models
            that load non-strictly.
    """
    if ckpt_path is None:
        return

    # NOTE(security): `torch.load` unpickles the file; only load checkpoints
    # from trusted sources (consider `weights_only=True` where supported).
    # Loading onto CPU keeps this independent of the target device.
    sd = torch.load(ckpt_path, map_location="cpu")["model"]
    missing_keys, unexpected_keys = model.load_state_dict(sd)

    if missing_keys:
        logging.error(missing_keys)
        raise RuntimeError(
            f"Checkpoint {ckpt_path!r} is missing keys required by the model: "
            f"{list(missing_keys)}"
        )
    if unexpected_keys:
        logging.error(unexpected_keys)
        raise RuntimeError(
            f"Checkpoint {ckpt_path!r} contains keys the model does not expect: "
            f"{list(unexpected_keys)}"
        )
    logging.info("Loaded checkpoint sucessfully")
|