| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
- # pyre-unsafe
- import json
- import os
- import torch
- from PIL import Image
- from sam3.model.box_ops import box_xyxy_to_xywh
- from sam3.train.masks_ops import rle_encode
- from .helpers.mask_overlap_removal import remove_overlapping_masks
- from .viz import visualize
def sam3_inference(processor, image_path, text_prompt):
    """Run SAM 3 image inference with a text prompt and format the outputs.

    Args:
        processor: SAM 3 processor exposing ``set_image`` and
            ``set_text_prompt``; the state it returns must provide
            ``"boxes"``, ``"masks"`` and ``"scores"`` tensors.
        image_path: Path of the image to segment.
        text_prompt: Text describing the objects to segment.

    Returns:
        A JSON-serializable dict with keys:
            ``orig_img_h`` / ``orig_img_w``: original image size in pixels,
            ``pred_boxes``: one ``[x, y, w, h]`` list per prediction, with
                coordinates divided by the image width/height,
            ``pred_masks``: the RLE ``counts`` entry of each encoded mask,
            ``pred_scores``: one score per prediction.
    """
    # Keep the file open only while inference needs it, so the handle is
    # released deterministically rather than whenever the Image is collected.
    with Image.open(image_path) as image:
        orig_img_w, orig_img_h = image.size
        # model inference
        inference_state = processor.set_image(image)
        inference_state = processor.set_text_prompt(
            state=inference_state, prompt=text_prompt
        )

    # format and assemble outputs
    # NOTE(review): assumes "boxes" is (N, 4) xyxy in pixel coordinates.
    # Dividing x by width and y by height normalizes them to [0, 1].
    boxes = inference_state["boxes"]
    pred_boxes_xyxy = torch.stack(
        [
            boxes[:, 0] / orig_img_w,
            boxes[:, 1] / orig_img_h,
            boxes[:, 2] / orig_img_w,
            boxes[:, 3] / orig_img_h,
        ],
        dim=-1,
    )
    pred_boxes_xywh = box_xyxy_to_xywh(pred_boxes_xyxy).tolist()

    # Presumably masks are (N, 1, H, W); squeeze(1) drops the singleton dim
    # before RLE encoding. Only the "counts" field of each RLE is kept.
    encoded_masks = rle_encode(inference_state["masks"].squeeze(1))
    pred_masks = [m["counts"] for m in encoded_masks]

    return {
        "orig_img_h": orig_img_h,
        "orig_img_w": orig_img_w,
        "pred_boxes": pred_boxes_xywh,
        "pred_masks": pred_masks,
        "pred_scores": inference_state["scores"].tolist(),
    }
def _sort_predictions_by_score(response):
    """Reorder pred_scores/pred_boxes/pred_masks in place, highest score first."""
    scores = response.get("pred_scores")
    if not scores:
        return
    # sorted() is stable, so tied scores keep their original relative order.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    for key in ("pred_scores", "pred_boxes", "pred_masks"):
        response[key] = [response[key][i] for i in order]


def _drop_short_rle_masks(response, min_length=5):
    """Drop predictions whose RLE mask string is shorter than ``min_length``."""
    # Such short encodings are treated as invalid masks; the matching box and
    # score are dropped with them so the three lists stay aligned.
    masks = response["pred_masks"]
    keep = [i for i, rle in enumerate(masks) if len(rle) >= min_length]
    response["pred_masks"] = [masks[i] for i in keep]
    response["pred_boxes"] = [response["pred_boxes"][i] for i in keep]
    response["pred_scores"] = [response["pred_scores"][i] for i in keep]


def call_sam_service(
    sam3_processor,
    image_path: str,
    text_prompt: str,
    output_folder_path: str = "sam3_output",
):
    """Run SAM 3 on an image with a text prompt, save results, and visualize.

    Outputs go to ``<output_folder_path>/<image_path with "/" -> "-">/``, as
    ``<text_prompt with "/" -> "_">.json`` (predictions) and ``.png``
    (visualization).

    Args:
        sam3_processor: Processor passed through to ``sam3_inference``.
        image_path: Path of the input image.
        text_prompt: Text describing the objects to segment.
        output_folder_path: Root folder for all outputs.

    Returns:
        The path of the JSON output file. Any exception raised while running
        inference, post-processing or saving is printed and swallowed, so the
        returned file may not exist.
    """
    print(f"📞 Loading image '{image_path}' and sending with prompt '{text_prompt}'...")

    # Flatten "/" in both names so neither creates nested directories/files.
    output_dir = os.path.join(output_folder_path, image_path.replace("/", "-"))
    text_prompt_for_save_path = text_prompt.replace("/", "_")
    os.makedirs(output_dir, exist_ok=True)
    output_json_path = os.path.join(output_dir, f"{text_prompt_for_save_path}.json")
    output_image_path = os.path.join(output_dir, f"{text_prompt_for_save_path}.png")

    try:
        # Run SAM 3 inference locally on the image with the text prompt.
        serialized_response = sam3_inference(sam3_processor, image_path, text_prompt)
        # 1. Remove overlapping masks, then prepend the input/output image paths.
        serialized_response = {
            "original_image_path": image_path,
            "output_image_path": output_image_path,
            **remove_overlapping_masks(serialized_response),
        }
        # 2. Reorder predictions by score (highest to lowest) when available.
        _sort_predictions_by_score(serialized_response)
        # 3. Remove invalid RLE masks that are too short (< 5 characters).
        _drop_short_rle_masks(serialized_response)

        with open(output_json_path, "w") as f:
            json.dump(serialized_response, f, indent=4)
        print(f"✅ Raw JSON response saved to '{output_json_path}'")

        # 4. Render the visualization and save it next to the JSON output
        #    (output_dir was already created above).
        print("🔍 Rendering visualizations on the image ...")
        viz_image = visualize(serialized_response)
        viz_image.save(output_image_path)
        print("✅ Saved visualization at:", output_image_path)
    except Exception as e:
        # Best-effort: report the failure and still return the intended path.
        print(f"❌ Error calling service: {e}")
    return output_json_path
|