| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- import warnings
- warnings.filterwarnings("ignore")
- import numpy as np
- import torch
- from segment_anything_hq import sam_model_registry, SamPredictor
- import os,requests
- from PIL import Image
def show_mask(mask):
    """Render a segmentation mask as a four-channel ``uint8`` image.

    Foreground pixels become 255 in every channel and all other pixels
    become 0, so the result reads as opaque white on a fully transparent
    background when interpreted as RGBA.

    Args:
        mask: NumPy array whose last two dimensions are (height, width) and
            whose total size is height * width (leading singleton dimensions
            are fine). Values are cast to ``uint8``; any element that is
            nonzero after the cast counts as foreground.

    Returns:
        A ``uint8`` array of shape (height, width, 4).
    """
    h, w = mask.shape[-2:]
    # Same foreground rule as before: cast to uint8 first, then test != 0.
    foreground = mask.astype(np.uint8).reshape(h, w) != 0
    # Build the uint8 image directly instead of scaling by a colour vector,
    # clipping to [0, 1] and rescaling by 255 through int64 temporaries.
    mask_image = np.zeros((h, w, 4), dtype=np.uint8)
    mask_image[foreground] = 255
    return mask_image
def show_masks(masks, scores):
    """Save every mask as a PNG named ``sam_1_ax_<index>.png``.

    Files are written to the current working directory. ``scores`` is not
    used for rendering; it only bounds the iteration, because masks and
    scores are walked in lockstep and the shorter of the two ends the loop.

    Args:
        masks: Iterable of masks accepted by ``show_mask``.
        scores: Iterable of per-mask scores, paired positionally with
            ``masks``.
    """
    paired = zip(masks, scores)
    for i, pair in enumerate(paired):
        rendered = show_mask(pair[0])
        Image.fromarray(rendered).save(f"sam_1_ax_{i}.png")
# --- Model setup -----------------------------------------------------------
# Checkpoint path and registry key for the SAM-HQ model used by this script.
sam_checkpoint = "./models/sam_hq_vit_h.pth"
model_type = "vit_h"
# Prefer the GPU, but fall back to CPU so the script still runs on machines
# without CUDA instead of failing when the model is moved to the device.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

# --- Input image -----------------------------------------------------------
image_url = "https://ossimg.valimart.net/uploads/vali_ai/20260129/176968033473923.png"
# A timeout keeps the script from hanging on a stalled connection, and the
# status check surfaces HTTP errors before PIL tries to parse an error page.
response = requests.get(image_url, stream=True, timeout=30)
response.raise_for_status()
# Let urllib3 undo any gzip/deflate transfer encoding before PIL reads the
# raw stream.
response.raw.decode_content = True
raw_image = Image.open(response.raw).convert("RGB")
image_array = np.array(raw_image)
predictor.set_image(image_array)

# --- Point prompts ---------------------------------------------------------
# One (x, y) pixel coordinate per row; labels follow SAM's convention of
# 1 = foreground point, 0 = background point.
input_point = np.array([[296, 543], [150, 543], [296, 683], [150, 683]])
input_label = np.array([1, 0, 1, 1])
hq_token_only = True

# First pass: request several candidate masks for the prompts.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
    hq_token_only=hq_token_only,
)

# Second pass: feed the logits of the highest-scoring candidate back in as a
# mask prompt and ask for a single refined mask.
mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
masks, scores, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
    hq_token_only=hq_token_only,
)

# Write the refined mask(s) to PNG files in the working directory.
show_masks(masks, scores)
|