# segmentation.py — FastAPI routes for SAM3 point-prompt image segmentation.
  1. import warnings
  2. warnings.filterwarnings("ignore")
  3. from fastapi import APIRouter, HTTPException, UploadFile, File
  4. from fastapi.responses import JSONResponse
  5. from .models import (
  6. SegmentationRequest,
  7. # SegmentationWithBoxRequest,
  8. # SegmentationResponse,
  9. # PointCoords
  10. )
  11. # from segment_anything_hq import sam_model_registry, SamPredictor
  12. from typing import List
  13. import torch
  14. from PIL import Image
  15. import requests
  16. import numpy as np
  17. import base64
  18. import io,os
  19. from datetime import datetime
  20. # from sam2.build_sam import build_sam2
  21. # from sam2.sam2_image_predictor import SAM2ImagePredictor
  22. from utils.sam_utils import show_masks,convert_to_serializable
  23. import sam3
  24. from sam3 import build_sam3_image_model
  25. from sam3.model.sam3_image_processor import Sam3Processor
  26. router = APIRouter()
  27. # 初始化模型(应用启动时加载一次)
  28. device = "cuda" if torch.cuda.is_available() else "cpu"
  29. torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
  30. # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
  31. if torch.cuda.get_device_properties(0).major >= 8:
  32. torch.backends.cuda.matmul.allow_tf32 = True
  33. torch.backends.cudnn.allow_tf32 = True
  34. # 配置PyTorch设置
  35. if torch.cuda.is_available():
  36. # 设置默认dtype为float32以避免bfloat16问题
  37. torch.set_default_tensor_type('torch.cuda.FloatTensor')
  38. else:
  39. torch.set_default_tensor_type('torch.FloatTensor')
  40. # 关闭autocast以避免bfloat16问题
  41. torch.set_autocast_enabled(False)
  42. # sam1模型配置
  43. # sam_checkpoint = "./models/sam_hq_vit_h.pth"
  44. # model_type = "vit_h"
  45. # sam1_model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
  46. # sam1_model = sam1_model.float() # 强制转换为float32
  47. # sam1_model.to(device=device)
  48. # # sam2模型配置
  49. # sam2_checkpoint = "/root/.cache/modelscope/hub/models/facebook/sam2.1-hiera-large/sam2.1_hiera_large.pt"
  50. # model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
  51. # sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
  52. # sam3
  53. sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
  54. bpe_path = f"{sam3_root}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
  55. sam3_checkpoint = "/root/.cache/modelscope/hub/models/facebook/sam3/sam3.pt"
  56. sam3_model = build_sam3_image_model(device=device,bpe_path=bpe_path, enable_inst_interactivity=True,load_from_HF=False,checkpoint_path=sam3_checkpoint)
  57. print("device", device)
  58. @router.post("/segment_with_points")
  59. async def segment_with_points(request: SegmentationRequest):
  60. image_url = request.image_url
  61. raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
  62. # input_point = np.array([[296, 543], [150, 543], [296, 683], [150, 683]])
  63. # input_label = np.array([1, 0, 1, 1])
  64. input_point = np.array(request.points)
  65. input_label = np.array(request.labels)
  66. # match request.type:
  67. # case 0:
  68. # predictor = SamPredictor(sam1_model)
  69. # raw_image = np.array(raw_image)
  70. # predictor.set_image(raw_image)
  71. # with torch.cuda.amp.autocast(enabled=False): # 临时禁用autocast
  72. # masks, scores, logits = predictor.predict(
  73. # point_coords=input_point,
  74. # point_labels=input_label,
  75. # multimask_output=True,
  76. # hq_token_only=True,
  77. # )
  78. # mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
  79. # masks, scores, _ = predictor.predict(
  80. # point_coords=input_point,
  81. # point_labels=input_label,
  82. # mask_input=mask_input[None, :, :],
  83. # multimask_output=False,
  84. # hq_token_only=True,
  85. # )
  86. # case 1:
  87. # # sam2
  88. # predictor = SAM2ImagePredictor(sam2_model)
  89. # predictor.set_image(raw_image)
  90. # masks, scores, logits = predictor.predict(
  91. # point_coords=input_point,
  92. # point_labels=input_label,
  93. # multimask_output=True,
  94. # )
  95. # mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
  96. # masks, scores, _ = predictor.predict(
  97. # point_coords=input_point,
  98. # point_labels=input_label,
  99. # mask_input=mask_input[None, :, :],
  100. # multimask_output=False,
  101. # )
  102. # case 2:
  103. # # sam3
  104. # processor = Sam3Processor(sam3_model,device=device)
  105. # inference_state = processor.set_image(raw_image)
  106. # masks, scores, logits = sam3_model.predict_inst(
  107. # inference_state,
  108. # point_coords=input_point,
  109. # point_labels=input_label,
  110. # multimask_output=True,
  111. # )
  112. # mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
  113. # masks, scores, logits = sam3_model.predict_inst(
  114. # inference_state,
  115. # point_coords=input_point,
  116. # point_labels=input_label,
  117. # mask_input=mask_input[None, :, :],
  118. # multimask_output=False,
  119. # )
  120. processor = Sam3Processor(sam3_model,device=device)
  121. inference_state = processor.set_image(raw_image)
  122. masks, scores, logits = sam3_model.predict_inst(
  123. inference_state,
  124. point_coords=input_point,
  125. point_labels=input_label,
  126. multimask_output=True,
  127. )
  128. mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
  129. masks, scores, logits = sam3_model.predict_inst(
  130. inference_state,
  131. point_coords=input_point,
  132. point_labels=input_label,
  133. mask_input=mask_input[None, :, :],
  134. multimask_output=False,
  135. )
  136. baseData = show_masks(masks,scores)
  137. # 转换为可序列化的格式
  138. serializable_result = {
  139. "code": 0,
  140. "data": {
  141. "base_64_data": convert_to_serializable(baseData)
  142. }
  143. }
  144. return serializable_result
  145. @router.get("/")
  146. async def root():
  147. return {"message": "Welcome to SAM2 Segmentation API"}