| 12345678910111213141516171819202122232425262728293031323334353637 |
# Two-pass SAM2 point-prompt segmentation of a remote image, with timing.
#
# Pass 1 asks the predictor for several candidate masks; pass 2 feeds the
# highest-scoring candidate's logits back in as `mask_input` and asks for a
# single mask, which is then passed to `show_masks`.
import io
from datetime import datetime

import numpy as np
import requests
import torch
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from utils.sam_utils import show_masks

sam2_checkpoint = "/root/.cache/modelscope/hub/models/facebook/sam2.1-hiera-large/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Seconds to wait for the image server (passed to requests as `timeout`).
# Without a timeout a stalled server would block the script indefinitely.
IMAGE_FETCH_TIMEOUT = 30


def _load_image(url, timeout=IMAGE_FETCH_TIMEOUT):
    """Download *url* and return it as an RGB ``PIL.Image.Image``.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait for the server, forwarded to ``requests.get``.

    Raises:
        requests.HTTPError: the server answered with a non-2xx status, so an
            error page is never handed to PIL as if it were an image.
        requests.Timeout: the server did not respond within *timeout*.
    """
    # `with` guarantees the connection is released even if decoding fails.
    with requests.get(url, timeout=timeout) as response:
        response.raise_for_status()
        # `.content` is decoded according to Content-Encoding (e.g. gzip);
        # `.raw` is the undecoded socket stream, which PIL cannot parse when
        # the server compresses the response.
        payload = response.content
    # convert() makes PIL decode the whole buffer right away.
    return Image.open(io.BytesIO(payload)).convert("RGB")


device = "cuda" if torch.cuda.is_available() else "cpu"
print("device", device)
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

image_url = "https://ossimg.valimart.net/uploads/vali_ai/20260129/176968033473923.png"
raw_image = _load_image(image_url)
# set_image() is called once; both predict() calls below run on this image.
predictor.set_image(raw_image)

# The timer starts after set_image(), so the measurement covers only the two
# predict() passes and show_masks() -- not download, model build or set_image().
start_time = datetime.now()

# Point prompts as (x, y) pixel coordinates with one label per point.
# Per the SAM2ImagePredictor.predict docs, label 1 marks a foreground point and
# 0 a background point (here the second point is background).
input_point = np.array([[296, 543], [150, 543], [296, 683], [150, 683]])
input_label = np.array([1, 0, 1, 1])

# Pass 1: request multiple candidate masks plus their scores and logits.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

# Pass 2: same prompts, plus the best candidate's logits (with a leading axis
# of size 1 added) as a mask prompt; ask for a single output mask.
masks, scores, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
baseData = show_masks(masks, scores)

end = datetime.now()
time_diff = end - start_time
# Prints "execution time: <N> seconds".
print(f"执行时间: {time_diff.total_seconds()} 秒")
|