# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time

import numpy as np
import torch
from tqdm import tqdm

from sam2.build_sam import build_sam2_video_predictor
# Only cuda supported
assert torch.cuda.is_available()
device = torch.device("cuda")
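
# Enter bfloat16 autocast for the whole script; calling __enter__() directly
# keeps the context active without wrapping everything in a `with` block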
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Config and checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
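# (The checkpoint path assumes the script is run from the repo root; the config
# name is resolved by Hydra relative to the sam2 package)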

# Build video predictor with vos_optimized=True setting
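# vos_optimized=True selects the speed-optimized VOS predictor path, which
# compiles model components with torch.compile; the compilation cost is paid
# during the warmup runs below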
predictor = build_sam2_video_predictor(
    model_cfg, sam2_checkpoint, device=device, vos_optimized=True
)

# Initialize with video
video_dir = "notebooks/videos/bedroom"
# scan all the JPEG frame names in this directory
frame_names = [
    p
    for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
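# init_state loads and caches all frames from video_dir up front, so per-frame
# image I/O stays out of the timed loop below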
inference_state = predictor.init_state(video_path=video_dir)

# Number of runs, warmup etc
warm_up, runs = 5, 25
verbose = True
num_frames = len(frame_names)
total, count = 0, 0
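# Release any cached GPU memory so timing starts from a clean allocator state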
torch.cuda.empty_cache()

# We will select an object with a click.
# See video_predictor_example.ipynb for more detailed explanation
ann_frame_idx, ann_obj_id = 0, 1
# Add a positive click at (x, y) = (210, 350)
# For labels, `1` means positive click
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], np.int32)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)
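# The positive click is now registered on frame ann_frame_idx; propagate_in_video
# below tracks the selected object through the remaining frames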

# Warmup and then average FPS over several runs
with torch.autocast("cuda", torch.bfloat16):
    with torch.inference_mode():
        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
            start = time.time()
            # Start tracking
            for (
                out_frame_idx,
                out_obj_ids,
                out_mask_logits,
            ) in predictor.propagate_in_video(inference_state):
                pass

            end = time.time()
            total += end - start
            count += 1

            if i == warm_up - 1:
                print("Warmup FPS: ", count * num_frames / total)
                total = 0
                count = 0
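
# count/total now reflect only the post-warmup runs (runs - warm_up of them)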
- print("FPS: ", count * num_frames / total)