# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe

import base64
import os
from typing import Any, Optional

from openai import OpenAI


def get_image_base64_and_mime(image_path: str) -> tuple[Optional[str], Optional[str]]:
    """Convert an image file to a base64 string and infer its MIME type."""
    try:
        # Infer the MIME type from the file extension
        ext = os.path.splitext(image_path)[1].lower()
        mime_types = {
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".png": "image/png",
            ".gif": "image/gif",
            ".webp": "image/webp",
            ".bmp": "image/bmp",
        }
        mime_type = mime_types.get(ext, "image/jpeg")  # Default to JPEG

        # Read the image bytes and encode them as base64
        with open(image_path, "rb") as image_file:
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
        return base64_data, mime_type
    except Exception as e:
        print(f"Error converting image to base64: {e}")
        return None, None
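

# A minimal usage sketch for get_image_base64_and_mime. The path
# "example.png" is a hypothetical placeholder, not a file shipped with
# this module.
def _example_encode_image() -> None:
    base64_data, mime_type = get_image_base64_and_mime("example.png")
    if base64_data is not None:
        # This data URL format is what the chat API expects for inline images
        data_url = f"data:{mime_type};base64,{base64_data}"
        print(f"Encoded as {mime_type}, {len(base64_data)} base64 characters")
        print(data_url[:80] + "...")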


def send_generate_request(
    messages,
    server_url=None,
    model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
    api_key=None,
    max_tokens=4096,
):
    """
    Send a request to an OpenAI-compatible API endpoint using the OpenAI
    client library.

    Args:
        messages (list): A list of message dicts, each containing a role and content.
        server_url (str): The base URL of the server, e.g. "http://127.0.0.1:8000".
            If None, the OpenAI client's default base URL is used.
        model (str): The model to use for generation
            (default: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8").
        api_key (str): API key passed to the OpenAI client, if the server requires one.
        max_tokens (int): Maximum number of tokens to generate (default: 4096).

    Returns:
        str: The generated response text from the server, or None on failure.
    """
    # Process messages, converting local image paths to base64 data URLs
    processed_messages = []
    for message in messages:
        processed_message = message.copy()
        if message["role"] == "user" and "content" in message:
            processed_content = []
            for c in message["content"]:
                if isinstance(c, dict) and c.get("type") == "image":
                    # Convert the image path entry to base64 data-URL format
                    image_path = c["image"]
                    new_image_path = image_path.replace("?", "%3F")  # Escape ? in the path
                    try:
                        base64_image, mime_type = get_image_base64_and_mime(
                            new_image_path
                        )
                        if base64_image is None:
                            print(
                                f"Warning: Could not convert image to base64: {new_image_path}"
                            )
                            continue
                        # Build the image_url structure with the base64 data
                        processed_content.append(
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{mime_type};base64,{base64_image}",
                                    "detail": "high",
                                },
                            }
                        )
                    except FileNotFoundError:
                        print(f"Warning: Image file not found: {new_image_path}")
                        continue
                    except Exception as e:
                        print(f"Warning: Error processing image {new_image_path}: {e}")
                        continue
                else:
                    processed_content.append(c)
            processed_message["content"] = processed_content
        processed_messages.append(processed_message)
    # Create an OpenAI client pointed at the custom base URL
    client = OpenAI(api_key=api_key, base_url=server_url)
    try:
        print(f"🔍 Calling model {model}...")
        response = client.chat.completions.create(
            model=model,
            messages=processed_messages,
            max_completion_tokens=max_tokens,
            n=1,
        )
        # Extract the response content from the first choice
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content
        else:
            print(f"Unexpected response format: {response}")
            return None
    except Exception as e:
        print(f"Request failed: {e}")
        return None
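

# A minimal usage sketch for send_generate_request, assuming an
# OpenAI-compatible server (e.g. `vllm serve`) is reachable at the URL
# below; the URL, API key, and image path are hypothetical placeholders.
def _example_server_request() -> None:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image", "image": "example.png"},
            ],
        }
    ]
    reply = send_generate_request(
        messages,
        server_url="http://127.0.0.1:8000/v1",
        api_key="EMPTY",  # many local servers accept any placeholder key
        max_tokens=512,
    )
    print(reply)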


def send_direct_request(
    llm: Any,
    messages: list[dict[str, Any]],
    sampling_params: Any,
) -> Optional[str]:
    """
    Run inference on a vLLM model instance directly, without going through a server.

    Args:
        llm: An initialized vLLM LLM instance (created by the caller).
        messages: List of message dicts with role and content (OpenAI format).
        sampling_params: A vLLM SamplingParams instance (created by the caller).

    Returns:
        str: The generated response text, or None if inference fails.
    """
    try:
        # Process messages, converting local image paths to base64 data URLs
        processed_messages = []
        for message in messages:
            processed_message = message.copy()
            if message["role"] == "user" and "content" in message:
                processed_content = []
                for c in message["content"]:
                    if isinstance(c, dict) and c.get("type") == "image":
                        # Convert the image path entry to base64 data-URL format
                        image_path = c["image"]
                        new_image_path = image_path.replace("?", "%3F")  # Escape ? in the path
                        try:
                            base64_image, mime_type = get_image_base64_and_mime(
                                new_image_path
                            )
                            if base64_image is None:
                                print(f"Warning: Could not convert image: {new_image_path}")
                                continue
                            # vLLM's chat interface expects the image_url format
                            processed_content.append(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{base64_image}"
                                    },
                                }
                            )
                        except Exception as e:
                            print(f"Warning: Error processing image {new_image_path}: {e}")
                            continue
                    else:
                        processed_content.append(c)
                processed_message["content"] = processed_content
            processed_messages.append(processed_message)

        print("🔍 Running direct inference with vLLM...")
        # Run inference using vLLM's chat interface
        outputs = llm.chat(
            messages=processed_messages,
            sampling_params=sampling_params,
        )

        # Extract the generated text from the first output
        if outputs and len(outputs) > 0:
            return outputs[0].outputs[0].text
        else:
            print(f"Unexpected output format: {outputs}")
            return None
    except Exception as e:
        print(f"Direct inference failed: {e}")
        return None
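

# A minimal end-to-end sketch for send_direct_request. The image path is a
# hypothetical placeholder, loading the model requires significant GPU
# memory, and the vLLM import is deferred so this module stays importable
# when vLLM is not installed.
if __name__ == "__main__":
    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8")
    sampling_params = SamplingParams(temperature=0.0, max_tokens=512)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image", "image": "example.png"},
            ],
        }
    ]
    print(send_direct_request(llm, messages, sampling_params))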