# sam2_demo.py - SAM2 point-prompt image segmentation demo.

import io
import time
from datetime import datetime

import numpy as np
import requests
import torch
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from utils.sam_utils import show_masks
  9. sam2_checkpoint = "/root/.cache/modelscope/hub/models/facebook/sam2.1-hiera-large/sam2.1_hiera_large.pt"
  10. model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
  11. device = "cuda" if torch.cuda.is_available() else "cpu"
  12. print("device",device)
  13. sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
  14. predictor = SAM2ImagePredictor(sam2_model)
  15. image_url = "https://ossimg.valimart.net/uploads/vali_ai/20260129/176968033473923.png"
  16. raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
  17. predictor.set_image(raw_image)
  18. start_time = datetime.now()
  19. input_point = [[296, 543], [150, 543], [296, 683], [150, 683]]
  20. input_label = [1, 0, 1, 1]
  21. masks, scores, logits = predictor.predict(
  22. point_coords=input_point,
  23. point_labels=input_label,
  24. multimask_output=True,
  25. )
  26. mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
  27. masks, scores, _ = predictor.predict(
  28. point_coords=input_point,
  29. point_labels=input_label,
  30. mask_input=mask_input[None, :, :],
  31. multimask_output=False,
  32. )
  33. baseData = show_masks(masks,scores)
  34. end = datetime.now()
  35. time_diff = end - start_time
  36. print(f"执行时间: {time_diff.total_seconds()} 秒")