import os
# if using Apple MPS, fall back to CPU for unsupported ops
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import requests
import sam3
from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from datetime import datetime

# use only the first GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")
if device.type == "cuda":
    # use bfloat16 for the whole script
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
np.random.seed(3)

def show_mask(mask):
    # render a binary mask as a white RGBA image (alpha marks the masked pixels)
    color = np.array([255, 255, 255, 1])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    # clamp values to [0, 1], then convert to uint8 in [0, 255]
    mask_image = np.clip(mask_image, 0, 1)
    mask_image_uint8 = (mask_image * 255).astype(np.uint8)
    return mask_image_uint8

def show_masks(image, masks, scores):
    # save each predicted mask as its own PNG file
    for i, (mask, score) in enumerate(zip(masks, scores)):
        np_arr = show_mask(mask)
        arr_img = Image.fromarray(np_arr)
        arr_img.save(f"sam3_ax_{i}.png")
image_url = "https://ossimg.valimart.net/uploads/vali_ai/20260129/176968033473923.png"
raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

bpe_path = f"{sam3_root}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
checkpoint_path = "/root/.cache/modelscope/hub/models/facebook/sam3/sam3.pt"
model = build_sam3_image_model(
    bpe_path=bpe_path,
    enable_inst_interactivity=True,
    load_from_HF=False,
    checkpoint_path=checkpoint_path,
)

start_time = datetime.now()
processor = Sam3Processor(model)
inference_state = processor.set_image(raw_image)

# point prompts: label 1 marks a foreground click, label 0 a background click
input_point = [[296, 543], [150, 543], [296, 683], [150, 683]]
input_label = [1, 0, 1, 1]

# first pass: ask for multiple candidate masks
masks, scores, logits = model.predict_inst(
    inference_state,
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
mask_input = logits[np.argmax(scores), :, :]  # choose the model's best mask

# second pass: feed the best low-resolution mask back in and ask for a single refined mask
masks, scores, logits = model.predict_inst(
    inference_state,
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
show_masks(raw_image, masks, scores)

end = datetime.now()
time_diff = end - start_time
print(f"execution time: {time_diff.total_seconds()} seconds")