sam1_demo.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import warnings
  2. warnings.filterwarnings("ignore")
  3. import numpy as np
  4. import torch
  5. from segment_anything_hq import sam_model_registry, SamPredictor
  6. import os,requests
  7. from PIL import Image
  8. def show_mask(mask):
  9. color = np.array([255, 255, 255, 1])
  10. h, w = mask.shape[-2:]
  11. mask = mask.astype(np.uint8)
  12. mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
  13. # 将值域限制在 [0, 1],然后转换为 [0, 255] 的 uint8 类型
  14. mask_image = np.clip(mask_image, 0, 1) # 确保值在 [0, 1] 范围内
  15. mask_image_uint8 = (mask_image * 255).astype(np.uint8)
  16. return mask_image_uint8
  17. def show_masks(masks, scores):
  18. for i, (mask, score) in enumerate(zip(masks, scores)):
  19. np_arr = show_mask(mask)
  20. arr_img = Image.fromarray(np_arr)
  21. arr_img.save(f"sam_1_ax_{i}.png")
  22. sam_checkpoint = "./models/sam_hq_vit_h.pth"
  23. model_type = "vit_h"
  24. device = "cuda"
  25. sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
  26. sam.to(device=device)
  27. predictor = SamPredictor(sam)
  28. image_url = "https://ossimg.valimart.net/uploads/vali_ai/20260129/176968033473923.png"
  29. raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
  30. image_array = np.array(raw_image)
  31. predictor.set_image(image_array)
  32. input_point = np.array([[296, 543], [150, 543], [296, 683], [150, 683]])
  33. input_label = np.array([1, 0, 1, 1])
  34. hq_token_only = True
  35. masks, scores, logits = predictor.predict(
  36. point_coords=input_point,
  37. point_labels=input_label,
  38. multimask_output=True,
  39. hq_token_only=hq_token_only,
  40. )
  41. mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
  42. masks, scores, _ = predictor.predict(
  43. point_coords=input_point,
  44. point_labels=input_label,
  45. mask_input=mask_input[None, :, :],
  46. multimask_output=False,
  47. hq_token_only=hq_token_only,
  48. )
  49. # masks = masks.squeeze(1).cpu().numpy()
  50. show_masks(masks, scores)