# sam3_demo.py
  1. import os
  2. # if using Apple MPS, fall back to CPU for unsupported ops
  3. # os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
  4. import numpy as np
  5. import torch
  6. import matplotlib.pyplot as plt
  7. from PIL import Image
  8. import sam3
  9. import os,requests
  10. from sam3 import build_sam3_image_model
  11. from sam3.model.sam3_image_processor import Sam3Processor
  12. from datetime import datetime
  13. # 指定使用第0块GPU
  14. os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  15. sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
  16. # select the device for computation
  17. if torch.cuda.is_available():
  18. device = torch.device("cuda")
  19. elif torch.backends.mps.is_available():
  20. device = torch.device("mps")
  21. else:
  22. device = torch.device("cpu")
  23. print(f"using device: {device}")
  24. torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
  25. # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
  26. if torch.cuda.get_device_properties(0).major >= 8:
  27. torch.backends.cuda.matmul.allow_tf32 = True
  28. torch.backends.cudnn.allow_tf32 = True
  29. np.random.seed(3)
  30. def show_mask(mask):
  31. color = np.array([255, 255, 255, 1])
  32. h, w = mask.shape[-2:]
  33. mask = mask.astype(np.uint8)
  34. mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
  35. # 将值域限制在 [0, 1],然后转换为 [0, 255] 的 uint8 类型
  36. mask_image = np.clip(mask_image, 0, 1) # 确保值在 [0, 1] 范围内
  37. mask_image_uint8 = (mask_image * 255).astype(np.uint8)
  38. return mask_image_uint8
  39. def show_masks(image, masks, scores):
  40. for i, (mask, score) in enumerate(zip(masks, scores)):
  41. np_arr = show_mask(mask)
  42. arr_img = Image.fromarray(np_arr)
  43. arr_img.save(f"sam3_ax_{i}.png")
  44. image_url = "https://ossimg.valimart.net/uploads/vali_ai/20260129/176968033473923.png"
  45. raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
  46. bpe_path = f"{sam3_root}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
  47. checkpoint_path = "/root/.cache/modelscope/hub/models/facebook/sam3/sam3.pt"
  48. model = build_sam3_image_model(bpe_path=bpe_path, enable_inst_interactivity=True,load_from_HF=False,checkpoint_path=checkpoint_path)
  49. start_time = datetime.now()
  50. processor = Sam3Processor(model)
  51. inference_state = processor.set_image(raw_image)
  52. input_point = [[296, 543], [150, 543], [296, 683], [150, 683]]
  53. input_label = [1, 0, 1, 1]
  54. masks, scores, logits = model.predict_inst(
  55. inference_state,
  56. point_coords=input_point,
  57. point_labels=input_label,
  58. multimask_output=True,
  59. )
  60. mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
  61. masks, scores, logits = model.predict_inst(
  62. inference_state,
  63. point_coords=input_point,
  64. point_labels=input_label,
  65. mask_input=mask_input[None, :, :],
  66. multimask_output=False,
  67. )
  68. show_masks(raw_image, masks, scores)
  69. end = datetime.now()
  70. time_diff = end - start_time
  71. print(f"执行时间: {time_diff.total_seconds()} 秒")