inference.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
  2. # pyre-unsafe
  3. import json
  4. import os
  5. from sam3.agent.agent_core import agent_inference
  6. def run_single_image_inference(
  7. image_path,
  8. text_prompt,
  9. llm_config,
  10. send_generate_request,
  11. call_sam_service,
  12. output_dir="agent_output",
  13. debug=False,
  14. ):
  15. """Run inference on a single image with provided prompt"""
  16. llm_name = llm_config["name"]
  17. if not os.path.exists(image_path):
  18. raise FileNotFoundError(f"Image file not found: {image_path}")
  19. # Create output directory
  20. os.makedirs(output_dir, exist_ok=True)
  21. # Generate output file names
  22. image_basename = os.path.splitext(os.path.basename(image_path))[0]
  23. prompt_for_filename = text_prompt.replace("/", "_").replace(" ", "_")
  24. base_filename = f"{image_basename}_{prompt_for_filename}_agent_{llm_name}"
  25. output_json_path = os.path.join(output_dir, f"{base_filename}_pred.json")
  26. output_image_path = os.path.join(output_dir, f"{base_filename}_pred.png")
  27. agent_history_path = os.path.join(output_dir, f"{base_filename}_history.json")
  28. # Check if output already exists and skip
  29. if os.path.exists(output_json_path):
  30. print(f"Output JSON {output_json_path} already exists. Skipping.")
  31. return
  32. print(f"{'-' * 30} Starting SAM 3 Agent Session... {'-' * 30} ")
  33. agent_history, final_output_dict, rendered_final_output = agent_inference(
  34. image_path,
  35. text_prompt,
  36. send_generate_request=send_generate_request,
  37. call_sam_service=call_sam_service,
  38. output_dir=output_dir,
  39. debug=debug,
  40. )
  41. print(f"{'-' * 30} End of SAM 3 Agent Session... {'-' * 30} ")
  42. final_output_dict["text_prompt"] = text_prompt
  43. final_output_dict["image_path"] = image_path
  44. # Save outputs
  45. json.dump(final_output_dict, open(output_json_path, "w"), indent=4)
  46. json.dump(agent_history, open(agent_history_path, "w"), indent=4)
  47. rendered_final_output.save(output_image_path)
  48. print(f"\n✅ Successfully processed single image!")
  49. print(f"Output JSON: {output_json_path}")
  50. print(f"Output Image: {output_image_path}")
  51. print(f"Agent History: {agent_history_path}")
  52. return output_image_path