"""Run SAM3 instance segmentation on one image from a set of point prompts.

Downloads a fixed image, builds a SAM3 image model from a local checkpoint,
predicts a mask from point prompts (two passes: multimask, then refinement
with the best logits), writes each resulting mask to ``sam3_ax_<i>.png`` and
prints the elapsed inference time.
"""
import os

# Restrict this process to the first GPU. The variable is only honoured if it
# is set before CUDA is initialised, so it is set ahead of the torch / sam3
# imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# If using Apple MPS, fall back to CPU for unsupported ops:
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import contextlib
import io
from datetime import datetime

import matplotlib.pyplot as plt  # noqa: F401  (currently unused)
import numpy as np
import requests
import torch
from PIL import Image

import sam3
from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

IMAGE_URL = "https://ossimg.valimart.net/uploads/vali_ai/20260129/176968033473923.png"
CHECKPOINT_PATH = "/root/.cache/modelscope/hub/models/facebook/sam3/sam3.pt"
SAM3_ROOT = os.path.join(os.path.dirname(sam3.__file__), "..")
BPE_PATH = f"{SAM3_ROOT}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
DOWNLOAD_TIMEOUT_S = 30.0


def show_mask(mask: np.ndarray) -> np.ndarray:
    """Render a mask as an RGBA ``uint8`` image of shape ``(h, w, 4)``.

    ``h`` and ``w`` are the last two dimensions of ``mask``; the mask must
    contain exactly ``h * w`` elements (e.g. ``(h, w)`` or ``(1, h, w)``).
    Pixels whose value is nonzero after casting to ``uint8`` become opaque
    white ``(255, 255, 255, 255)``; all other pixels become fully transparent
    black ``(0, 0, 0, 0)``.
    """
    h, w = mask.shape[-2:]
    # Cast first (as the original did) so fractional values < 1 count as "off".
    on = mask.astype(np.uint8).reshape(h, w) != 0
    rgba = np.zeros((h, w, 4), dtype=np.uint8)
    rgba[on] = 255
    return rgba


def show_masks(image, masks, scores) -> None:
    """Save each mask as ``sam3_ax_<i>.png`` in the current directory.

    ``image`` is currently unused. ``scores`` only bounds the iteration: like
    ``zip``, saving stops at the shorter of ``masks`` and ``scores``.
    """
    for i, (mask, _score) in enumerate(zip(masks, scores)):
        Image.fromarray(show_mask(mask)).save(f"sam3_ax_{i}.png")


def _select_device() -> torch.device:
    """Prefer CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _enable_tf32_if_supported(device: torch.device) -> None:
    """Turn on TF32 matmul/cuDNN on Ampere-or-newer GPUs; no-op off CUDA.

    See https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    """
    # Guard: get_device_properties raises when CUDA is unavailable.
    if device.type != "cuda":
        return
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


def _load_image(url: str, timeout: float = DOWNLOAD_TIMEOUT_S) -> Image.Image:
    """Download ``url`` and decode it as an RGB PIL image.

    Raises ``requests.HTTPError`` on a non-2xx response and
    ``requests.Timeout`` if the server does not respond within ``timeout``.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    # Decode from the (content-decoded) body rather than the raw socket stream.
    return Image.open(io.BytesIO(resp.content)).convert("RGB")


def main() -> None:
    """Download the image, run two-pass point-prompt prediction, save masks."""
    device = _select_device()
    print(f"using device: {device}")
    _enable_tf32_if_supported(device)
    np.random.seed(3)

    raw_image = _load_image(IMAGE_URL)

    # bfloat16 autocast only makes sense on CUDA; elsewhere run in default precision.
    autocast = (
        torch.autocast("cuda", dtype=torch.bfloat16)
        if device.type == "cuda"
        else contextlib.nullcontext()
    )
    with autocast:
        # NOTE(review): `device` is not passed to build_sam3_image_model; confirm
        # the builder places the model on the selected device (it may assume CUDA).
        model = build_sam3_image_model(
            bpe_path=BPE_PATH,
            enable_inst_interactivity=True,
            load_from_HF=False,
            checkpoint_path=CHECKPOINT_PATH,
        )

        start_time = datetime.now()
        processor = Sam3Processor(model)
        inference_state = processor.set_image(raw_image)

        # Pixel (x, y) prompt points and their labels; presumably SAM convention
        # (1 = foreground, 0 = background) -- confirm for sam3.
        input_point = np.array([[296, 543], [150, 543], [296, 683], [150, 683]])
        input_label = np.array([1, 0, 1, 1])

        # Pass 1: ask for several candidate masks.
        masks, scores, logits = model.predict_inst(
            inference_state,
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        # Logits of the highest-scoring candidate, used as a prior for pass 2.
        mask_input = logits[np.argmax(scores), :, :]

        # Pass 2: refine into a single mask using the same points plus that prior.
        masks, scores, logits = model.predict_inst(
            inference_state,
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False,
        )
        show_masks(raw_image, masks, scores)

    end = datetime.now()
    time_diff = end - start_time
    print(f"执行时间: {time_diff.total_seconds()} 秒")


if __name__ == "__main__":
    main()