# benchmark.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import os
  6. import time
  7. import numpy as np
  8. import torch
  9. from tqdm import tqdm
  10. from sam2.build_sam import build_sam2_video_predictor
  11. # Only cuda supported
  12. assert torch.cuda.is_available()
  13. device = torch.device("cuda")
  14. torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
  15. if torch.cuda.get_device_properties(0).major >= 8:
  16. # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
  17. torch.backends.cuda.matmul.allow_tf32 = True
  18. torch.backends.cudnn.allow_tf32 = True
  19. # Config and checkpoint
  20. sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
  21. model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
  22. # Build video predictor with vos_optimized=True setting
  23. predictor = build_sam2_video_predictor(
  24. model_cfg, sam2_checkpoint, device=device, vos_optimized=True
  25. )
  26. # Initialize with video
  27. video_dir = "notebooks/videos/bedroom"
  28. # scan all the JPEG frame names in this directory
  29. frame_names = [
  30. p
  31. for p in os.listdir(video_dir)
  32. if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
  33. ]
  34. frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
  35. inference_state = predictor.init_state(video_path=video_dir)
  36. # Number of runs, warmup etc
  37. warm_up, runs = 5, 25
  38. verbose = True
  39. num_frames = len(frame_names)
  40. total, count = 0, 0
  41. torch.cuda.empty_cache()
  42. # We will select an object with a click.
  43. # See video_predictor_example.ipynb for more detailed explanation
  44. ann_frame_idx, ann_obj_id = 0, 1
  45. # Add a positive click at (x, y) = (210, 350)
  46. # For labels, `1` means positive click
  47. points = np.array([[210, 350]], dtype=np.float32)
  48. labels = np.array([1], np.int32)
  49. _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
  50. inference_state=inference_state,
  51. frame_idx=ann_frame_idx,
  52. obj_id=ann_obj_id,
  53. points=points,
  54. labels=labels,
  55. )
  56. # Warmup and then average FPS over several runs
  57. with torch.autocast("cuda", torch.bfloat16):
  58. with torch.inference_mode():
  59. for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
  60. start = time.time()
  61. # Start tracking
  62. for (
  63. out_frame_idx,
  64. out_obj_ids,
  65. out_mask_logits,
  66. ) in predictor.propagate_in_video(inference_state):
  67. pass
  68. end = time.time()
  69. total += end - start
  70. count += 1
  71. if i == warm_up - 1:
  72. print("Warmup FPS: ", count * num_frames / total)
  73. total = 0
  74. count = 0
  75. print("FPS: ", count * num_frames / total)