client_llm.py

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import base64
import os
from typing import Any, Optional

from openai import OpenAI


def get_image_base64_and_mime(image_path):
    """Convert an image file to a base64 string and determine its MIME type."""
    try:
        # Get MIME type based on file extension
        ext = os.path.splitext(image_path)[1].lower()
        mime_types = {
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".png": "image/png",
            ".gif": "image/gif",
            ".webp": "image/webp",
            ".bmp": "image/bmp",
        }
        mime_type = mime_types.get(ext, "image/jpeg")  # Default to JPEG

        # Convert image to base64
        with open(image_path, "rb") as image_file:
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
        return base64_data, mime_type
    except Exception as e:
        print(f"Error converting image to base64: {e}")
        return None, None
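
# Example usage (a minimal sketch, not part of the original module; the file name is a
# placeholder): the helper returns the raw base64 payload and the MIME type, which
# callers assemble into a data URL such as "data:image/png;base64,<data>".
#
# b64, mime = get_image_base64_and_mime("screenshot.png")
# data_url = f"data:{mime};base64,{b64}" if b64 is not None else None
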
def send_generate_request(
    messages,
    server_url=None,
    model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
    api_key=None,
    max_tokens=4096,
):
    """
    Send a request to an OpenAI-compatible API endpoint using the OpenAI client library.

    Args:
        messages (list): A list of message dicts, each containing a role and content.
        server_url (str): The base URL of the server, e.g. "http://127.0.0.1:8000".
        model (str): The model to use for generation
            (default: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8").
        api_key (str): API key passed to the OpenAI client.
        max_tokens (int): Maximum number of tokens to generate (default: 4096).

    Returns:
        str: The generated response text from the server, or None if the request fails.
    """
    # Process messages to convert image paths to base64
    processed_messages = []
    for message in messages:
        processed_message = message.copy()
        if message["role"] == "user" and "content" in message:
            processed_content = []
            for c in message["content"]:
                if isinstance(c, dict) and c.get("type") == "image":
                    # Convert image path to base64 format
                    image_path = c["image"]
                    print("image_path", image_path)
                    new_image_path = image_path.replace("?", "%3F")  # Escape ? in the path
                    # Read the image file and convert to base64
                    try:
                        base64_image, mime_type = get_image_base64_and_mime(
                            new_image_path
                        )
                        if base64_image is None:
                            print(
                                f"Warning: Could not convert image to base64: {new_image_path}"
                            )
                            continue
                        # Create the proper image_url structure with base64 data
                        processed_content.append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{mime_type};base64,{base64_image}",
                                    "detail": "high",
                                },
                            }
                        )
                    except FileNotFoundError:
                        print(f"Warning: Image file not found: {new_image_path}")
                        continue
                    except Exception as e:
                        print(f"Warning: Error processing image {new_image_path}: {e}")
                        continue
                else:
                    processed_content.append(c)
            processed_message["content"] = processed_content
        processed_messages.append(processed_message)

    # Create OpenAI client with custom base URL
    client = OpenAI(api_key=api_key, base_url=server_url)
    try:
        print(f"🔍 Calling model {model}...")
        response = client.chat.completions.create(
            model=model,
            messages=processed_messages,
            max_completion_tokens=max_tokens,
            n=1,
        )
        # print(f"Received response: {response.choices[0].message}")
        # Extract the response content
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content
        else:
            print(f"Unexpected response format: {response}")
            return None
    except Exception as e:
        print(f"Request failed: {e}")
        return None
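
# Example usage (a minimal sketch, not part of the original module): the server URL,
# API key, and image path below are placeholders and assume an OpenAI-compatible
# endpoint (e.g. a vLLM or llama-stack server) is already running at that address.
#
# example_messages = [
#     {
#         "role": "user",
#         "content": [
#             {"type": "text", "text": "Describe this screenshot."},
#             {"type": "image", "image": "/path/to/screenshot.png"},
#         ],
#     }
# ]
# reply = send_generate_request(
#     example_messages,
#     server_url="http://127.0.0.1:8000/v1",
#     api_key="EMPTY",
# )
# print(reply)
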
def send_direct_request(
    llm: Any,
    messages: list[dict[str, Any]],
    sampling_params: Any,
) -> Optional[str]:
    """
    Run inference on a vLLM model instance directly, without going through a server.

    Args:
        llm: Initialized vLLM LLM instance (passed from external initialization).
        messages: List of message dicts with role and content (OpenAI format).
        sampling_params: vLLM SamplingParams instance (initialized externally).

    Returns:
        str: Generated response text, or None if inference fails.
    """
    try:
        # Process messages to handle images (convert to base64 if needed)
        processed_messages = []
        for message in messages:
            processed_message = message.copy()
            if message["role"] == "user" and "content" in message:
                processed_content = []
                for c in message["content"]:
                    if isinstance(c, dict) and c.get("type") == "image":
                        # Convert image path to base64 format
                        image_path = c["image"]
                        new_image_path = image_path.replace("?", "%3F")  # Escape ? in the path
                        try:
                            base64_image, mime_type = get_image_base64_and_mime(
                                new_image_path
                            )
                            if base64_image is None:
                                print(
                                    f"Warning: Could not convert image: {new_image_path}"
                                )
                                continue
                            # vLLM expects the OpenAI-style image_url format
                            processed_content.append(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{base64_image}"
                                    },
                                }
                            )
                        except Exception as e:
                            print(
                                f"Warning: Error processing image {new_image_path}: {e}"
                            )
                            continue
                    else:
                        processed_content.append(c)
                processed_message["content"] = processed_content
            processed_messages.append(processed_message)

        print("🔍 Running direct inference with vLLM...")
        # Run inference using vLLM's chat interface
        outputs = llm.chat(
            messages=processed_messages,
            sampling_params=sampling_params,
        )
        # Extract the generated text from the first output
        if outputs and len(outputs) > 0:
            generated_text = outputs[0].outputs[0].text
            return generated_text
        else:
            print(f"Unexpected output format: {outputs}")
            return None
    except Exception as e:
        print(f"Direct inference failed: {e}")
        return None
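
# Example usage (a minimal sketch, not part of the original module): assumes vLLM is
# installed and the checkpoint named below is available; the model name and sampling
# settings are placeholders, not values prescribed by this module.
#
# from vllm import LLM, SamplingParams
#
# llm = LLM(model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8")
# sampling_params = SamplingParams(temperature=0.0, max_tokens=4096)
# reply = send_direct_request(
#     llm,
#     [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}],
#     sampling_params,
# )
# print(reply)
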