import io
from datetime import datetime

import numpy as np
import requests
import torch
from PIL import Image

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from utils.sam_utils import show_masks

# Segment one remote image with SAM 2.1 (Hiera-Large) from a few point
# prompts, refine the best candidate mask with a second prediction pass,
# hand the result to the project's show_masks helper, and print how long
# the prediction stage took.

sam2_checkpoint = "/root/.cache/modelscope/hub/models/facebook/sam2.1-hiera-large/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Seconds to wait on the image server before giving up (the original
# request had no timeout and could block forever).
IMAGE_DOWNLOAD_TIMEOUT = 30

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device", device)

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

image_url = "https://ossimg.valimart.net/uploads/vali_ai/20260129/176968033473923.png"
# Download the full body instead of reading ``response.raw``:
# - ``.raw`` is the undecoded socket stream (breaks on gzip/deflate encoding).
# - The streamed connection was never closed.
# - raise_for_status() surfaces HTTP errors instead of passing an error page
#   to PIL.
response = requests.get(image_url, timeout=IMAGE_DOWNLOAD_TIMEOUT)
response.raise_for_status()
raw_image = Image.open(io.BytesIO(response.content)).convert("RGB")
predictor.set_image(raw_image)

# Timing covers both predict passes and show_masks only (not model loading,
# the download, or set_image).
start_time = datetime.now()

# Point prompts in pixel coordinates; label 1 marks a foreground point and
# 0 a background point, so the second point here is a negative prompt.
input_point = np.array([[296, 543], [150, 543], [296, 683], [150, 683]])
input_label = np.array([1, 0, 1, 1])

# First pass: request several candidate masks for the prompt set.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

# Feed the low-resolution logits of the highest-scoring candidate back in as
# a mask prompt so the second pass refines that mask.
mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
masks, scores, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],  # add the leading axis (1 x H x W)
    multimask_output=False,
)

# NOTE(review): show_masks is a project helper; its return value is kept in
# baseData but not used further in this script. Confirm whether it should be
# saved or returned somewhere.
baseData = show_masks(masks, scores)

end = datetime.now()
time_diff = end - start_time
# Prints the elapsed time in seconds (label reads "execution time ... seconds").
print(f"执行时间: {time_diff.total_seconds()} 秒")