client_sam3.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import json
  4. import os
  5. import torch
  6. from PIL import Image
  7. from sam3.model.box_ops import box_xyxy_to_xywh
  8. from sam3.train.masks_ops import rle_encode
  9. from .helpers.mask_overlap_removal import remove_overlapping_masks
  10. from .viz import visualize
  11. def sam3_inference(processor, image_path, text_prompt):
  12. """Run SAM 3 image inference with text prompts and format the outputs"""
  13. image = Image.open(image_path)
  14. orig_img_w, orig_img_h = image.size
  15. # model inference
  16. inference_state = processor.set_image(image)
  17. inference_state = processor.set_text_prompt(
  18. state=inference_state, prompt=text_prompt
  19. )
  20. # format and assemble outputs
  21. pred_boxes_xyxy = torch.stack(
  22. [
  23. inference_state["boxes"][:, 0] / orig_img_w,
  24. inference_state["boxes"][:, 1] / orig_img_h,
  25. inference_state["boxes"][:, 2] / orig_img_w,
  26. inference_state["boxes"][:, 3] / orig_img_h,
  27. ],
  28. dim=-1,
  29. ) # normalized in range [0, 1]
  30. pred_boxes_xywh = box_xyxy_to_xywh(pred_boxes_xyxy).tolist()
  31. pred_masks = rle_encode(inference_state["masks"].squeeze(1))
  32. pred_masks = [m["counts"] for m in pred_masks]
  33. outputs = {
  34. "orig_img_h": orig_img_h,
  35. "orig_img_w": orig_img_w,
  36. "pred_boxes": pred_boxes_xywh,
  37. "pred_masks": pred_masks,
  38. "pred_scores": inference_state["scores"].tolist(),
  39. }
  40. return outputs
  41. def call_sam_service(
  42. sam3_processor,
  43. image_path: str,
  44. text_prompt: str,
  45. output_folder_path: str = "sam3_output",
  46. ):
  47. """
  48. Loads an image, sends it with a text prompt to the service,
  49. saves the results, and renders the visualization.
  50. """
  51. print(f"📞 Loading image '{image_path}' and sending with prompt '{text_prompt}'...")
  52. text_prompt_for_save_path = (
  53. text_prompt.replace("/", "_") if "/" in text_prompt else text_prompt
  54. )
  55. os.makedirs(
  56. os.path.join(output_folder_path, image_path.replace("/", "-")), exist_ok=True
  57. )
  58. output_json_path = os.path.join(
  59. output_folder_path,
  60. image_path.replace("/", "-"),
  61. rf"{text_prompt_for_save_path}.json",
  62. )
  63. output_image_path = os.path.join(
  64. output_folder_path,
  65. image_path.replace("/", "-"),
  66. rf"{text_prompt_for_save_path}.png",
  67. )
  68. try:
  69. # Send the image and text prompt as a multipart/form-data request
  70. serialized_response = sam3_inference(sam3_processor, image_path, text_prompt)
  71. # 1. Prepare the response dictionary
  72. serialized_response = remove_overlapping_masks(serialized_response)
  73. serialized_response = {
  74. "original_image_path": image_path,
  75. "output_image_path": output_image_path,
  76. **serialized_response,
  77. }
  78. # 2. Reorder predictions by scores (highest to lowest) if scores are available
  79. if "pred_scores" in serialized_response and serialized_response["pred_scores"]:
  80. # Create indices sorted by scores in descending order
  81. score_indices = sorted(
  82. range(len(serialized_response["pred_scores"])),
  83. key=lambda i: serialized_response["pred_scores"][i],
  84. reverse=True,
  85. )
  86. # Reorder all three lists based on the sorted indices
  87. serialized_response["pred_scores"] = [
  88. serialized_response["pred_scores"][i] for i in score_indices
  89. ]
  90. serialized_response["pred_boxes"] = [
  91. serialized_response["pred_boxes"][i] for i in score_indices
  92. ]
  93. serialized_response["pred_masks"] = [
  94. serialized_response["pred_masks"][i] for i in score_indices
  95. ]
  96. # 3. Remove any invalid RLE masks that is too short (shorter than 5 characters)
  97. valid_masks = []
  98. valid_boxes = []
  99. valid_scores = []
  100. for i, rle in enumerate(serialized_response["pred_masks"]):
  101. if len(rle) > 4:
  102. valid_masks.append(rle)
  103. valid_boxes.append(serialized_response["pred_boxes"][i])
  104. valid_scores.append(serialized_response["pred_scores"][i])
  105. serialized_response["pred_masks"] = valid_masks
  106. serialized_response["pred_boxes"] = valid_boxes
  107. serialized_response["pred_scores"] = valid_scores
  108. with open(output_json_path, "w") as f:
  109. json.dump(serialized_response, f, indent=4)
  110. print(f"✅ Raw JSON response saved to '{output_json_path}'")
  111. # 4. Render and save visualizations on the image and save it in the SAM3 output folder
  112. print("🔍 Rendering visualizations on the image ...")
  113. viz_image = visualize(serialized_response)
  114. os.makedirs(os.path.dirname(output_image_path), exist_ok=True)
  115. viz_image.save(output_image_path)
  116. print("✅ Saved visualization at:", output_image_path)
  117. except Exception as e:
  118. print(f"❌ Error calling service: {e}")
  119. return output_json_path