"""FastAPI router exposing SAM3 point-prompted image segmentation."""

import warnings

warnings.filterwarnings("ignore")

import base64
import io
import os
from datetime import datetime
from typing import List

import numpy as np
import requests
import torch
from fastapi import APIRouter, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image

import sam3
from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

from utils.sam_utils import show_masks, convert_to_serializable

from .models import SegmentationRequest

router = APIRouter()

# ---------------------------------------------------------------------------
# Model initialisation (runs once at application startup).
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

if torch.cuda.is_available():
    # Enable TF32 matmuls on Ampere-or-newer GPUs (compute capability >= 8).
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    # NOTE: the availability guard is required — get_device_properties(0)
    # raises on CPU-only hosts, which this module otherwise supports.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    # Default to float32 tensors to avoid bfloat16 precision issues.
    # NOTE(review): set_default_tensor_type is deprecated in recent torch;
    # torch.set_default_dtype / torch.set_default_device are the modern
    # equivalents — confirm the installed torch version before migrating.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

# Keep autocast disabled globally to avoid bfloat16-related issues.
# (A previous revision also entered a bfloat16 autocast context at import
# time without ever exiting it, which contradicted this setting; that
# leaked __enter__() has been removed.)
torch.set_autocast_enabled(False)

# SAM3 model configuration — the BPE vocabulary ships inside the sam3 package.
sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
bpe_path = f"{sam3_root}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
sam3_checkpoint = "/root/.cache/modelscope/hub/models/facebook/sam3/sam3.pt"
sam3_model = build_sam3_image_model(
    device=device,
    bpe_path=bpe_path,
    enable_inst_interactivity=True,
    load_from_HF=False,
    checkpoint_path=sam3_checkpoint,
)
print("device", device)


@router.post("/segment_with_points")
async def segment_with_points(request: SegmentationRequest):
    """Segment an image with SAM3 using user-supplied point prompts.

    Downloads ``request.image_url``, then runs a two-pass SAM3 instance
    prediction: a multi-mask first pass, followed by a refinement pass
    seeded with the highest-scoring candidate mask.

    Returns a dict ``{"code": 0, "data": {"base_64_data": ...}}`` where the
    payload is the rendered masks serialized by ``convert_to_serializable``.

    Raises:
        HTTPException(400): when the image cannot be fetched.
    """
    try:
        # timeout + status check: a bad URL must not hang the worker or
        # hand PIL an HTML error page to parse.
        resp = requests.get(request.image_url, stream=True, timeout=30)
        resp.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=400, detail=f"Failed to fetch image: {exc}"
        ) from exc
    raw_image = Image.open(resp.raw).convert("RGB")

    # Point prompts: coordinates plus 1/0 foreground/background labels.
    input_point = np.array(request.points)
    input_label = np.array(request.labels)

    processor = Sam3Processor(sam3_model, device=device)
    inference_state = processor.set_image(raw_image)

    # First pass: ask for multiple candidate masks.
    masks, scores, logits = sam3_model.predict_inst(
        inference_state,
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    # Second pass: refine using the best-scoring candidate as a mask prior.
    mask_input = logits[np.argmax(scores), :, :]
    masks, scores, logits = sam3_model.predict_inst(
        inference_state,
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )

    base_data = show_masks(masks, scores)
    # convert_to_serializable flattens numpy/tensor types for JSON encoding.
    return {
        "code": 0,
        "data": {"base_64_data": convert_to_serializable(base_data)},
    }


@router.get("/")
async def root():
    """Health-check / welcome endpoint."""
    # This service runs SAM3 (the SAM2 wording was a leftover).
    return {"message": "Welcome to SAM3 Segmentation API"}