"""Point-prompted segmentation demo using SAM-HQ.

Downloads an image, prompts the SAM-HQ predictor with a handful of
foreground/background points, refines the best candidate mask with a second
prediction pass, and writes the resulting mask(s) to PNG files in the current
working directory.
"""
import warnings

# Silence library warnings before the heavy third-party imports below run.
warnings.filterwarnings("ignore")

import io
import os

import numpy as np
import requests
import torch
from PIL import Image
from segment_anything_hq import sam_model_registry, SamPredictor

# Upper bound (seconds) on connecting to / reading from the image server, so a
# stalled connection cannot hang the script indefinitely.
REQUEST_TIMEOUT_SECONDS = 30


def show_mask(mask):
    """Convert a segmentation mask into a uint8 RGBA image array.

    Args:
        mask: Array whose last two dimensions are (height, width); any leading
            dimensions must have total size 1. Values are cast to ``uint8``
            first, and anything non-zero after that cast counts as foreground.

    Returns:
        ``numpy.ndarray`` of shape ``(height, width, 4)`` and dtype ``uint8``:
        every channel is 255 (opaque white) on foreground pixels and 0
        (transparent black) elsewhere.
    """
    h, w = mask.shape[-2:]
    foreground = mask.astype(np.uint8).reshape(h, w, 1) != 0
    # Same value in all four channels (R, G, B, A).
    return np.repeat(foreground.astype(np.uint8) * 255, 4, axis=2)


def show_masks(masks, scores):
    """Save each mask as ``sam_1_ax_<index>.png`` in the working directory.

    Args:
        masks: Iterable of masks accepted by :func:`show_mask`.
        scores: Iterable of per-mask scores. Only used to pair with ``masks``;
            iteration stops at the shorter of the two.
    """
    for i, (mask, _score) in enumerate(zip(masks, scores)):
        Image.fromarray(show_mask(mask)).save(f"sam_1_ax_{i}.png")


sam_checkpoint = "./models/sam_hq_vit_h.pth"
model_type = "vit_h"
# Prefer the GPU, but fall back to the CPU instead of crashing when CUDA is
# not available on this machine.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

image_url = "https://ossimg.valimart.net/uploads/vali_ai/20260129/176968033473923.png"
response = requests.get(image_url, timeout=REQUEST_TIMEOUT_SECONDS)
# Fail with a clear HTTP error rather than an obscure image-decoding error
# when the server answers with 4xx/5xx.
response.raise_for_status()
raw_image = Image.open(io.BytesIO(response.content)).convert("RGB")
image_array = np.array(raw_image)

predictor.set_image(image_array)

# Prompt points as (x, y) pixel coordinates; label 1 marks foreground and
# label 0 marks background.
input_point = np.array([[296, 543], [150, 543], [296, 683], [150, 683]])
input_label = np.array([1, 0, 1, 1])

hq_token_only = True

# First pass: request multiple candidate masks for the prompt.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
    hq_token_only=hq_token_only,
)

# Choose the model's best mask (highest score) and feed its logits back in.
mask_input = logits[np.argmax(scores), :, :]

# Second pass: refine using the best first-pass logits as a mask prompt and
# ask for a single output mask.
masks, scores, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
    hq_token_only=hq_token_only,
)

show_masks(masks, scores)