# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import copy
import json
import os

import cv2
from PIL import Image

from .client_llm import send_generate_request
from .client_sam3 import call_sam_service
from .viz import visualize


def save_debug_messages(messages_list, debug, debug_folder_path, debug_jsonl_path):
    """Save messages to the debug JSONL file if debug is enabled."""
    if debug and debug_jsonl_path:
        # Ensure the debug directory exists before writing
        os.makedirs(debug_folder_path, exist_ok=True)
        with open(debug_jsonl_path, "w") as f:
            for msg in messages_list:
                # One compact JSON object per line, as the JSONL format requires
                # (indented JSON would span multiple lines and break the file)
                f.write(json.dumps(msg) + "\n")


def cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path):
    """Clean up debug files when the function returns successfully."""
    if debug and debug_folder_path:
        try:
            if os.path.exists(debug_jsonl_path):
                os.remove(debug_jsonl_path)
            if os.path.exists(debug_folder_path):
                os.rmdir(debug_folder_path)
        except Exception as e:
            print(f"Warning: Could not clean up debug files: {e}")


def count_images(messages):
    """Count the total number of images present in the messages history."""
    total = 0
    for message in messages:
        # Check if the message has content (should be a list)
        if "content" in message and isinstance(message["content"], list):
            # Count each content item that is a dict with type "image"
            for content_item in message["content"]:
                if (
                    isinstance(content_item, dict)
                    and content_item.get("type") == "image"
                ):
                    total += 1
    return total


def _prune_messages_for_next_round(
    messages_list,
    used_text_prompts,
    latest_sam3_text_prompt,
    img_path,
    initial_text_prompt,
):
    """Return a new messages list that contains only:

    1) messages[:2] (with optional warning text added to the second message's content)
    2) the latest assistant message (and everything after it) that contains a
       segment_phrase tool call
    """
    # There should not be more than 10 messages in the conversation history
    assert len(messages_list) < 10
    # Part 1: always keep the first two messages
    part1 = copy.deepcopy(messages_list[:2])
    # Part 2: search backwards for the latest assistant message containing a
    # segment_phrase tool call
    part2_start_idx = None
    for idx in range(len(messages_list) - 1, 1, -1):
        msg = messages_list[idx]
        # Only consider assistant messages with a "content" list
        if msg.get("role") != "assistant" or "content" not in msg:
            continue
        # Look for any text content element containing the segment_phrase tool call
        for content in msg["content"]:
            if (
                isinstance(content, dict)
                and content.get("type") == "text"
                and "<tool>" in content.get("text", "")
                and "segment_phrase" in content.get("text", "")
            ):
                part2_start_idx = idx
                break
        if part2_start_idx is not None:
            break
    part2 = messages_list[part2_start_idx:] if part2_start_idx is not None else []
    # Part 3: decide whether to add warning text to the second message in part1
    previously_used = (
        [p for p in used_text_prompts if p != latest_sam3_text_prompt]
        if latest_sam3_text_prompt
        else list(used_text_prompts)
    )
    if part2 and len(previously_used) > 0:
        warning_text = f'Note that we have previously called the segment_phrase tool with each "text_prompt" in this list: {list(previously_used)}, but none of the generated results were satisfactory. So make sure that you do not use any of these phrases as the "text_prompt" to call the segment_phrase tool again.'
        # Replace the second message entirely to keep exactly 2 content items
        part1[1] = {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {
                    "type": "text",
                    "text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'."
                    + " "
                    + warning_text,
                },
            ],
        }
        assert len(part1[1]["content"]) == 2
    # Build the new messages list: part1 (with optional warning), then part2
    new_messages = list(part1)
    new_messages.extend(part2)
    return new_messages


def agent_inference(
    img_path: str,
    initial_text_prompt: str,
    debug: bool = False,
    send_generate_request=send_generate_request,
    call_sam_service=call_sam_service,
    max_generations: int = 100,
    output_dir="../../sam3_agent_out",
):
    """
    Given a text prompt and an image, perform all aspects of agentic problem
    solving, saving SAM3 and MLLM outputs to their respective directories.

    Args:
        img_path: Path to the input image
        initial_text_prompt: Initial text prompt from the user
        debug: Whether to enable debug mode
        send_generate_request: Callable used to query the MLLM
        call_sam_service: Callable used to query the SAM3 service
        max_generations: Maximum number of send_generate_request calls allowed (default: 100)
        output_dir: Root directory for SAM3, error, and debug outputs
    """
    # Set up output directories
    sam_output_dir = os.path.join(output_dir, "sam_out")
    error_save_dir = os.path.join(output_dir, "none_out")
    debug_save_dir = os.path.join(output_dir, "agent_debug_out")
    os.makedirs(sam_output_dir, exist_ok=True)
    os.makedirs(error_save_dir, exist_ok=True)
    os.makedirs(debug_save_dir, exist_ok=True)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    MLLM_SYSTEM_PROMPT_PATH = os.path.join(
        current_dir, "system_prompts/system_prompt.txt"
    )
    ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH = os.path.join(
        current_dir, "system_prompts/system_prompt_iterative_checking.txt"
    )
    # Initialize state
    PATH_TO_LATEST_OUTPUT_JSON = ""
    LATEST_SAM3_TEXT_PROMPT = ""
    # Track all previously used text prompts for segment_phrase
    USED_TEXT_PROMPTS = set()
    generation_count = 0  # Counter for the number of send_generate_request calls
    # Debug setup
    debug_folder_path = None
    debug_jsonl_path = None
    if debug:
        debug_folder_path = os.path.join(
            debug_save_dir, f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}"
        )
        debug_jsonl_path = os.path.join(debug_folder_path, "debug_history.jsonl")
        os.makedirs(debug_folder_path, exist_ok=True)
    with open(MLLM_SYSTEM_PROMPT_PATH, "r") as f:
        system_prompt = f.read().strip()
    with open(ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH, "r") as f:
        iterative_checking_system_prompt = f.read().strip()
    # Construct the initial message list
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {
                    "type": "text",
                    "text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'.",
                },
            ],
        },
    ]
    print(f"> Text prompt: {initial_text_prompt}")
    print(f"> Image path: {img_path}")
    print("\n\n")
    print("-" * 30 + f" Round {generation_count + 1} " + "-" * 30)
    print("\n\n")
    generated_text = send_generate_request(messages)
    print(f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n")
    while generated_text is not None:
        save_debug_messages(messages, debug, debug_folder_path, debug_jsonl_path)
        # Assert the condition directly; asserting a (condition, message) tuple
        # is always truthy and never fires
        assert (
            "<tool>" in generated_text
        ), f"Generated text does not contain <tool> tag: {generated_text}"
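        # Expected tool-call format (an assumption inferred from the parsing
        # below; the authoritative schema lives in the system prompt):
        #   <tool>{"name": "segment_phrase", "parameters": {"text_prompt": "..."}}</tool>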
        generated_text = generated_text.split("</tool>", 1)[0] + "</tool>"
        tool_call_json_str = (
            generated_text.split("<tool>")[-1]
            .split("</tool>")[0]
            .strip()
            .replace(r"}}}", r"}}")  # remove extra } if any
        )
        try:
            tool_call = json.loads(tool_call_json_str)
        except json.JSONDecodeError:
            raise ValueError(f"Invalid JSON in tool call: {tool_call_json_str}")
        if PATH_TO_LATEST_OUTPUT_JSON == "":
            # The first tool call must be segment_phrase or report_no_mask
            assert tool_call["name"] in ("segment_phrase", "report_no_mask")
        if tool_call["name"] == "segment_phrase":
            print("🔍 Calling segment_phrase tool...")
            assert list(tool_call["parameters"].keys()) == ["text_prompt"]
            # Check if this text_prompt has been used before
            current_text_prompt = tool_call["parameters"]["text_prompt"]
            if current_text_prompt in USED_TEXT_PROMPTS:
                print(
                    f"❌ Text prompt '{current_text_prompt}' has been used before. Requesting a different prompt."
                )
                duplicate_prompt_message = f"You have previously used '{current_text_prompt}' as your text_prompt to call the segment_phrase tool. You may not use it again. Please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase prompt, while adhering to all the rules stated in the system prompt. You must also never use any of the following text_prompt(s): {str(list(USED_TEXT_PROMPTS))}."
                messages.append(
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                )
                messages.append(
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": duplicate_prompt_message}],
                    }
                )
            else:
                # Add the text_prompt to the set of used prompts
                USED_TEXT_PROMPTS.add(current_text_prompt)
                LATEST_SAM3_TEXT_PROMPT = current_text_prompt
                PATH_TO_LATEST_OUTPUT_JSON = call_sam_service(
                    image_path=img_path,
                    text_prompt=current_text_prompt,
                    output_folder_path=sam_output_dir,
                )
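                # Assumed schema of the SAM3 output JSON (inferred from the
                # accesses below and in visualize): output_image_path,
                # original_image_path, orig_img_h, orig_img_w, pred_boxes,
                # pred_scores, pred_masks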
                with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as f:
                    sam3_outputs = json.load(f)
                sam3_output_image_path = sam3_outputs["output_image_path"]
                num_masks = len(sam3_outputs["pred_boxes"])
                messages.append(
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": generated_text}],
                    }
                )
                if num_masks == 0:
                    print("❌ No masks generated by SAM3, reporting no mask to Qwen.")
                    sam3_output_text_message = f"The segment_phrase tool did not generate any masks for the text_prompt '{current_text_prompt}'. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt. Please be reminded that the original user query was '{initial_text_prompt}'."
                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": sam3_output_text_message}
                            ],
                        }
                    )
                else:
                    sam3_output_text_message = f"The segment_phrase tool generated {num_masks} available masks. All {num_masks} available masks are rendered in this image below, now you must analyze the {num_masks} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action. Please be reminded that the original user query was '{initial_text_prompt}'."
                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": sam3_output_text_message},
                                {"type": "image", "image": sam3_output_image_path},
                            ],
                        }
                    )
                print("\n\n>>> sam3_output_text_message:\n", sam3_output_text_message)
        elif tool_call["name"] == "examine_each_mask":
            print("🔍 Calling examine_each_mask tool...")
            assert LATEST_SAM3_TEXT_PROMPT != ""
            # Make sure that the last message contains an image
            assert (
                messages[-1]["content"][1]["type"] == "image"
            ), "Second content element should be an image"
            messages.pop()  # Remove the last user message
            # Add a simplified replacement message
            simplified_message = {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "The segment_phrase tool generated several masks. Now you must analyze the mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
                    }
                ],
            }
            messages.append(simplified_message)
            with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as f:
                current_outputs = json.load(f)
            num_masks = len(current_outputs["pred_masks"])
            masks_to_keep = []
            # Sanitize the prompt so it is safe to embed in file names
            safe_prompt = LATEST_SAM3_TEXT_PROMPT.replace("/", "_")
            # The MLLM checks the masks one by one
            for i in range(num_masks):
                print(f"🔍 Checking mask {i + 1}/{num_masks}...")
                image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i)
                image_w_zoomed_in_mask_i_path = os.path.join(
                    sam_output_dir, f"{safe_prompt}_zoom_in_mask_{i + 1}.png"
                )
                image_w_mask_i_path = os.path.join(
                    sam_output_dir, f"{safe_prompt}_selected_mask_{i + 1}.png"
                )
                image_w_zoomed_in_mask_i.save(image_w_zoomed_in_mask_i_path)
                image_w_mask_i.save(image_w_mask_i_path)
                iterative_checking_messages = [
                    {"role": "system", "content": iterative_checking_system_prompt},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "The raw input image: "},
                            {"type": "image", "image": img_path},
                            {
                                "type": "text",
                                "text": f"The initial user input query is: '{initial_text_prompt}'",
                            },
                            {
                                "type": "text",
                                "text": "Image with the predicted segmentation mask rendered on it: ",
                            },
                            {"type": "image", "image": image_w_mask_i_path},
                            {
                                "type": "text",
                                "text": "Image with the zoomed-in mask: ",
                            },
                            {"type": "image", "image": image_w_zoomed_in_mask_i_path},
                        ],
                    },
                ]
                checking_generated_text = send_generate_request(
                    iterative_checking_messages
                )
                # Determine from the generated text whether the mask should be
                # kept or rejected
                if checking_generated_text is None:
                    raise ValueError(
                        "Generated text is None, which is unexpected. Please check the Qwen server and the input parameters."
                    )
                print(f"Generated text for mask {i + 1}: {checking_generated_text}")
                verdict = (
                    checking_generated_text.split("<verdict>")[-1]
                    .split("</verdict>")[0]
                    .strip()
                )
                if "Accept" in verdict:
                    assert "Reject" not in verdict
                    print(f"Mask {i + 1} accepted, keeping it in the outputs.")
                    masks_to_keep.append(i)
                elif "Reject" in verdict:
                    assert "Accept" not in verdict
                    print(f"Mask {i + 1} rejected, removing it from the outputs.")
                else:
                    raise ValueError(
                        f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'."
                    )
            updated_outputs = {
                "original_image_path": current_outputs["original_image_path"],
                "orig_img_h": current_outputs["orig_img_h"],
                "orig_img_w": current_outputs["orig_img_w"],
                "pred_boxes": [current_outputs["pred_boxes"][i] for i in masks_to_keep],
                "pred_scores": [
                    current_outputs["pred_scores"][i] for i in masks_to_keep
                ],
                "pred_masks": [current_outputs["pred_masks"][i] for i in masks_to_keep],
            }
            image_w_check_masks = visualize(updated_outputs)
            # Use the sanitized prompt (not the suffix) so a "/" in the prompt
            # cannot break the output path
            image_w_check_masks_path = os.path.join(
                sam_output_dir,
                f"{safe_prompt}_selected_masks_{'-'.join(str(i + 1) for i in masks_to_keep)}.png",
            )
            image_w_check_masks.save(image_w_check_masks_path)
            # Save the updated JSON outputs and append to the message history
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )
            if len(masks_to_keep) == 0:
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"The original user query was: '{initial_text_prompt}'. The examine_each_mask tool examined and rejected all of the masks generated by the segment_phrase tool. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt.",
                            }
                        ],
                    }
                )
            else:
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"The original user query was: '{initial_text_prompt}'. After calling the examine_each_mask tool on the available masks, the number of available masks is now {len(masks_to_keep)}. All {len(masks_to_keep)} available masks are rendered in this image below, now you must analyze the {len(masks_to_keep)} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
                            },
                            {"type": "image", "image": image_w_check_masks_path},
                        ],
                    }
                )
            # Create a new filename based on the original path to avoid filename
            # length issues
            base_path = PATH_TO_LATEST_OUTPUT_JSON
            # Remove any existing "masks_" suffix to avoid duplication
            if "masks_" in base_path:
                base_path = base_path.split("masks_")[0] + ".json"
            # Create a new filename with the current masks; use a clearer suffix
            # when empty
            if len(masks_to_keep) == 0:
                PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(".json", "masks_none.json")
            else:
                PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(
                    ".json", f"masks_{'_'.join(map(str, masks_to_keep))}.json"
                )
            with open(PATH_TO_LATEST_OUTPUT_JSON, "w") as f:
                json.dump(updated_outputs, f, indent=4)
        elif tool_call["name"] == "select_masks_and_return":
            print("🔍 Calling select_masks_and_return tool...")
            with open(PATH_TO_LATEST_OUTPUT_JSON, "r") as f:
                current_outputs = json.load(f)
            assert list(tool_call["parameters"].keys()) == ["final_answer_masks"]
            masks_to_keep = tool_call["parameters"]["final_answer_masks"]
            # Keep only valid mask indices, remove duplicates, and preserve a
            # deterministic ascending order
            available_masks = set(range(1, len(current_outputs["pred_masks"]) + 1))
            masks_to_keep = sorted({i for i in masks_to_keep if i in available_masks})
            # TODO: change this to an update message telling the model to try
            # again, along with information about the errors made
            final_outputs = {
                "original_image_path": current_outputs["original_image_path"],
                "orig_img_h": current_outputs["orig_img_h"],
                "orig_img_w": current_outputs["orig_img_w"],
                "pred_boxes": [
                    current_outputs["pred_boxes"][i - 1] for i in masks_to_keep
                ],
                "pred_scores": [
                    current_outputs["pred_scores"][i - 1] for i in masks_to_keep
                ],
                "pred_masks": [
                    current_outputs["pred_masks"][i - 1] for i in masks_to_keep
                ],
            }
            rendered_final_output = visualize(final_outputs)
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )
            # Clean up debug files before the successful return
            cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path)
            return messages, final_outputs, rendered_final_output
        elif tool_call["name"] == "report_no_mask":
            print("🔍 Calling report_no_mask tool...")
            height, width = cv2.imread(img_path).shape[:2]
            final_outputs = {
                "original_image_path": img_path,
                "orig_img_h": height,
                "orig_img_w": width,
                "pred_boxes": [],
                "pred_scores": [],
                "pred_masks": [],
            }
            rendered_final_output = Image.open(img_path)
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": generated_text}],
                }
            )
            # Mirror the cleanup done on the select_masks_and_return path,
            # since this is also a successful return
            cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path)
            return messages, final_outputs, rendered_final_output
        else:
            raise ValueError(f"Unknown tool call: {tool_call['name']}")
        # Sometimes the MLLM doesn't know when to stop and generates multiple
        # tool calls in one round, so we split the generated text by </tool>
        # and keep only the first tool call
        for message in messages:
            if message["role"] == "assistant" and "content" in message:
                for content in message["content"]:
                    if (
                        isinstance(content, dict)
                        and content.get("type") == "text"
                        and "text" in content
                    ):
                        content["text"] = (
                            content["text"].split("</tool>", 1)[0] + "</tool>\n\n"
                        )
        # Prune the message history before the next MLLM generation round
        # according to the 3-part rules. This keeps the history compact and
        # ensures the model sees only the allowed parts.
        messages = _prune_messages_for_next_round(
            messages,
            USED_TEXT_PROMPTS,
            LATEST_SAM3_TEXT_PROMPT,
            img_path,
            initial_text_prompt,
        )
        # Make sure there can never be more than 2 images in the context
        assert count_images(messages) <= 2
        generation_count += 1
        if generation_count > max_generations:
            raise ValueError(
                f"Exceeded maximum number of allowed generation requests ({max_generations})"
            )
        print("\n\n")
        print("-" * 30 + f" Round {generation_count + 1} " + "-" * 30)
        print("\n\n")
        generated_text = send_generate_request(messages)
        print(
            f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n"
        )
  504. print("\n\n>>> SAM 3 Agent execution ended.\n\n")
  505. error_save_path = os.path.join(
  506. error_save_dir,
  507. f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}_error_history.json",
  508. )
  509. with open(error_save_path, "w") as f:
  510. json.dump(messages, f, indent=4)
  511. print("Saved messages history that caused error to:", error_save_path)
  512. raise ValueError(
  513. rf"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters for image path: {img_path} and initial text prompt: {initial_text_prompt}."
  514. )
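

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not the repo's real entry point):
    # it presumes an image at the hypothetical path below and running LLM/SAM3
    # services behind the default client functions. Because this module uses
    # relative imports, run it as e.g. `python -m <package>.agent_core`.
    demo_messages, demo_outputs, demo_rendered = agent_inference(
        img_path="example.jpg",  # hypothetical input image
        initial_text_prompt="the red car",  # hypothetical user query
        debug=True,
    )
    print(f"Returned {len(demo_outputs['pred_masks'])} final mask(s).")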