| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- import warnings
- warnings.filterwarnings("ignore")
- from fastapi import APIRouter, HTTPException, UploadFile, File
- from fastapi.responses import JSONResponse
- from .models import (
- SegmentationRequest,
- # SegmentationWithBoxRequest,
- # SegmentationResponse,
- # PointCoords
- )
- # from segment_anything_hq import sam_model_registry, SamPredictor
- from typing import List
- import torch
- from PIL import Image
- import requests
- import numpy as np
- import base64
- import io,os
- from datetime import datetime
- # from sam2.build_sam import build_sam2
- # from sam2.sam2_image_predictor import SAM2ImagePredictor
- from utils.sam_utils import show_masks,convert_to_serializable
- import sam3
- from sam3 import build_sam3_image_model
- from sam3.model.sam3_image_processor import Sam3Processor
router = APIRouter()

# Initialize the model once at application startup (module import time).
device = "cuda" if torch.cuda.is_available() else "cpu"

if torch.cuda.is_available():
    # Turn on tfloat32 for Ampere+ GPUs (compute capability >= 8).
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    # NOTE: guarded by the availability check above — querying device 0 on a
    # CPU-only host would raise.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    # Default new tensors to float32 on CUDA to avoid bfloat16 issues.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

# Disable autocast globally to avoid bfloat16 problems in the predictors.
# (The previous unconditional `torch.autocast("cuda", ...).__enter__()` was
# contradicted by this call and broke CPU-only startup, so it was removed.)
torch.set_autocast_enabled(False)

# SAM3 configuration: resolve the BPE vocab shipped inside the installed
# sam3 package and load a local checkpoint (no Hugging Face download).
sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
bpe_path = f"{sam3_root}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
sam3_checkpoint = "/root/.cache/modelscope/hub/models/facebook/sam3/sam3.pt"
sam3_model = build_sam3_image_model(
    device=device,
    bpe_path=bpe_path,
    enable_inst_interactivity=True,
    load_from_HF=False,
    checkpoint_path=sam3_checkpoint,
)
print("device", device)
@router.post("/segment_with_points")
async def segment_with_points(request: SegmentationRequest):
    """Segment an image with SAM3 from user-supplied point prompts.

    Downloads ``request.image_url``, then runs a two-pass SAM3 instance
    prediction: a multimask pass first, followed by a refinement pass that
    feeds the best-scoring mask back in to produce a single final mask.

    Returns a dict ``{"code": 0, "data": {"base_64_data": ...}}`` where the
    payload is the rendered masks serialized to base64-compatible data.
    """
    # Download the image. Fail fast on HTTP errors instead of letting PIL
    # choke on an error page, and don't hang forever on an unresponsive host.
    resp = requests.get(request.image_url, stream=True, timeout=30)
    resp.raise_for_status()
    raw_image = Image.open(resp.raw).convert("RGB")

    input_point = np.array(request.points)
    input_label = np.array(request.labels)

    processor = Sam3Processor(sam3_model, device=device)
    inference_state = processor.set_image(raw_image)

    # First pass: request multiple candidate masks for the prompt points.
    masks, scores, logits = sam3_model.predict_inst(
        inference_state,
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    # Refinement pass: feed the model's best mask back in and ask for a
    # single, refined output.
    mask_input = logits[np.argmax(scores), :, :]
    masks, scores, _ = sam3_model.predict_inst(
        inference_state,
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )

    base_data = show_masks(masks, scores)
    # Convert to a JSON-serializable structure before returning.
    return {
        "code": 0,
        "data": {
            "base_64_data": convert_to_serializable(base_data)
        }
    }
@router.get("/")
async def root():
    """Landing endpoint for the segmentation service."""
    # The service loads and serves SAM3 (see the module setup), so greet
    # with the correct model name rather than the stale "SAM2".
    return {"message": "Welcome to SAM3 Segmentation API"}
|