"""Base API class for model inference.
Provides common functionality shared by all API backends:
message preprocessing, image encoding, retry logic, and parallel
request processing.
"""
import concurrent.futures
import copy as cp
import ctypes
import gc
import random as rd
import time
from abc import abstractmethod
# Bind glibc's malloc_trim so _collect_memory can hand freed heap pages
# back to the OS after large batches. On platforms without glibc
# (macOS, musl, Windows) the CDLL load or symbol lookup fails, and a
# no-op stub is installed instead.
try:
    _libc = ctypes.CDLL("libc.so.6")
    _malloc_trim = _libc.malloc_trim
except (OSError, AttributeError):
    def _malloc_trim(x):
        # Stub: mimic malloc_trim's "no memory released" return value.
        return 0
import numpy as np
from PIL import Image
from tqdm import tqdm
from datastudio.utils.json_parser import load_str_to_dict
from datastudio.utils.vision import encode_image_to_base64, parse_file
[docs]
class BaseAPI:
    """Base API class providing common functionality for all API models.

    This class handles:
    - Message preprocessing and formatting
    - Image encoding (base64)
    - Retry logic with exponential backoff
    - Parallel request processing

    Subclasses must implement ``generate_inner``; they are also expected
    to provide ``prepare_inputs``, which ``pre_process`` calls.
    """

    # Message item types accepted by pre_process; any other "type" value
    # fails validation there.
    allowed_types = ["text", "image", "image_url", "system"]
[docs]
def __init__(
self,
retry=10,
wait=5,
timeout=(30, 1800),
logger=None,
thread_num=384,
fail_msg="Failed to obtain answer via API.",
openai_format=True,
image_key_name="image",
return_dict=True,
use_system_proxy=False,
**kwargs,
):
"""Initialize the base API.
Args:
retry: Number of retry attempts on API failure.
wait: Wait time between retries (seconds).
timeout: Request timeout as (connect_timeout, read_timeout) tuple.
logger: Logger instance.
thread_num: Number of parallel threads.
fail_msg: Message returned on failure.
openai_format: Whether to use OpenAI message format.
image_key_name: Key name for image data.
return_dict: Whether to parse response as dict.
use_system_proxy: Whether to use system proxy.
**kwargs: Additional arguments passed to generate_inner.
"""
self.retry = retry
self.wait = wait
self.timeout = timeout
self.logger = logger
self.fail_msg = fail_msg
self.thread_num = thread_num
self.default_kwargs = kwargs if kwargs else {}
self.image_key_name = image_key_name
self.openai_format = openai_format
self.return_dict = return_dict
self.use_system_proxy = use_system_proxy
@staticmethod
def _collect_memory(trim: bool = False):
"""Run garbage collection and optionally trim glibc malloc arenas.
Args:
trim: If True, call malloc_trim(0) to return freed memory to the OS.
"""
gc.collect()
if trim:
_malloc_trim(0)
@staticmethod
def _clear_message_buffers(messages):
"""Clear all message buffers to release memory.
Args:
messages: List of message objects to clear.
"""
for msg in messages:
if isinstance(msg, list):
msg.clear()
messages.clear()
[docs]
@abstractmethod
def generate_inner(self, inputs, **kwargs):
    """Generate a response for the given inputs. Must be implemented by subclasses.

    Args:
        inputs: Preprocessed message list (output of ``pre_process``).
        **kwargs: Backend-specific generation arguments.

    Returns:
        Implementations are expected (by ``process_single_message``) to
        return a ``(ret_code, response_struct, response)`` tuple, where
        ``ret_code == 0`` and a truthy ``response_struct`` signal success.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError("Subclasses must implement generate_inner method")
[docs]
def encode_image_directly(self, img_item):
    """Encode an image to base64 without any caching layer.

    Args:
        img_item: A PIL Image, an already-encoded ``data:image`` string,
            or anything ``PIL.Image.open`` accepts (e.g. a file path).

    Returns:
        Base64-encoded string, or the input unchanged if it already is one.

    Raises:
        Exception: Re-raised from the underlying encoder on failure
            (after logging a warning when a logger is configured).
    """
    if isinstance(img_item, Image.Image):
        try:
            return encode_image_to_base64(img_item)
        except Exception as e:
            if self.logger:
                self.logger.warning(f"Image encoding failed: {str(e)}")
            raise
    # Strings already carrying a data URI pass through untouched.
    if isinstance(img_item, str) and img_item.startswith("data:image"):
        return img_item
    # Anything else (e.g. a file path): open, then encode recursively.
    return self.encode_image_directly(Image.open(img_item))
[docs]
def check_content(self, msgs):
"""Check input content type.
Args:
msgs: Raw input messages.
Returns:
str: Message type (str/dict/liststr/listdict).
"""
if isinstance(msgs, str):
return "str"
if isinstance(msgs, dict):
return "dict"
if isinstance(msgs, list):
types = [self.check_content(m) for m in msgs]
if all(t == "str" for t in types):
return "liststr"
if all(t == "dict" for t in types):
return "listdict"
return "unknown"
[docs]
def preproc_content(self, inputs):
"""Convert raw input messages to unified dict list format.
Args:
inputs: Raw input.
Returns:
list: Processed input message list.
"""
if self.check_content(inputs) == "str":
return [dict(type="text", value=inputs)]
elif self.check_content(inputs) == "dict":
assert "type" in inputs and "value" in inputs
return [inputs]
elif self.check_content(inputs) == "liststr":
res = []
for s in inputs:
mime, pth = parse_file(s, self.image_key_name)
if mime is None or mime == "unknown":
res.append(dict(type="text", value=s))
else:
res.append(dict(type=mime.split("/")[0], value=pth))
return res
elif self.check_content(inputs) == "listdict":
for item in inputs:
assert "type" in item and "value" in item
# Skip file parsing for system prompts (plain text)
if item["type"] == "system":
continue
mime, s = parse_file(item["value"], self.image_key_name)
if mime is None:
assert item["type"] == "text", item["value"]
else:
assert item["type"] in ["image", "image_url"]
item["type"] = mime
item["value"] = s
return inputs
else:
return None
def _parse_openai_response(self, data: dict) -> str:
"""Parse OpenAI-format API response content.
Args:
data: Parsed JSON response from OpenAI API.
Returns:
Extracted content string, with optional reasoning_content formatting.
"""
message_obj = data.get("choices", [{}])[0].get("message", {})
content = message_obj.get("content", self.fail_msg)
reasoning_content = message_obj.get("reasoning_content")
if self.return_dict:
return load_str_to_dict(content)
if reasoning_content:
return (
f"<think>\n{reasoning_content.strip()}\n</think>\n\n{content.strip()}"
)
return content
def _build_openai_headers(self, api_key: str) -> dict:
"""Build HTTP headers for OpenAI-compatible API request.
Args:
api_key: API key for authorization. If empty, no Authorization header.
Returns:
Headers dict with Content-Type and optional Authorization.
"""
headers = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def _calculate_retry_delay(self, attempt: int, max_delay: float = 2.0) -> float:
"""Calculate exponential backoff delay with full jitter.
Args:
attempt: Current retry attempt (0-indexed).
max_delay: Maximum delay cap in seconds.
Returns:
Delay in seconds (uniform random in ``[0, min(base * 2^attempt, max_delay)]``).
"""
base_delay = 0.1
exp_delay = min(base_delay * (2**attempt), min(self.wait, max_delay))
return rd.random() * exp_delay
[docs]
def process_single_message(self, message, all_kwargs):
"""Process a single message with retry logic.
Args:
message: Single processed message.
all_kwargs: Arguments passed to generate_inner.
Returns:
Response content or fail_msg on failure.
"""
try:
message = self.pre_process(message)
except Exception as e:
if self.logger:
self.logger.warning(f"Preprocessing failed: {str(e)}")
return self.fail_msg
# Initialize variables to avoid UnboundLocalError
ret_code = None
response_struct = None
response = None
for i in range(self.retry):
try:
ret_code, response_struct, response = self.generate_inner(
message, **all_kwargs
)
if ret_code == 0 and response_struct and response_struct != "":
return response_struct
else:
raise Exception("Invalid response")
except Exception as e:
# Use exponential backoff with jitter (faster than fixed wait)
if i < self.retry - 1:
delay = self._calculate_retry_delay(i)
time.sleep(delay)
if i + 1 > self.retry * 2 // 3 and self.logger:
self.logger.warning(
f"Attempt {i+1}/{self.retry} failed, RetCode: {ret_code}, "
f"Response: {response_struct}, Error: {type(e).__name__}: {str(e)[:200]}"
)
response_str = (
response.text
if hasattr(response, "text") and response
else str(response) if response else "No response"
)
if self.logger:
self.logger.warning(
f"Failed after {self.retry} attempts, RetCode: {ret_code}, Response: {response_str}"
)
return self.fail_msg
[docs]
def pre_process(self, msg):
    """Validate and normalize a raw input message into API-ready form.

    Args:
        msg: Raw input message (str, dict, or list).

    Returns:
        list: Preprocessed message list as produced by ``prepare_inputs``
        (a hook supplied by subclasses).

    Raises:
        AssertionError: If the input type or any item type is unsupported.
    """
    kind = self.check_content(msg)
    assert kind in (
        "str",
        "dict",
        "liststr",
        "listdict",
    ), f"Unsupported input type: {msg}"
    msg = self.preproc_content(msg)
    assert msg is not None and self.check_content(msg) == "listdict"
    for item in msg:
        assert (
            item["type"] in self.allowed_types
        ), f'Unsupported input type: {item["type"]}'
    # Delegate final formatting to the subclass-provided hook.
    processed_msg, _ = self.prepare_inputs(msg)
    return processed_msg
def _pre_encode_images(self, messages) -> None:
    """Pre-encode all PIL images to base64 strings in parallel.

    Encodes images before multi-threaded API calls to avoid PIL
    thread-safety issues. Uses ``id``-based deduplication to encode
    each unique image object only once.

    Args:
        messages: List of message dicts containing potential PIL images.
            Modified in place (PIL objects replaced with base64 strings).
    """
    # Pass 1: collect unique images, keyed by object identity so shared
    # image objects are encoded only once.
    images_to_encode = {}  # id -> image object
    for msg in messages:
        if not isinstance(msg, list):
            continue
        for item in msg:
            if not isinstance(item, dict):
                continue
            if item.get("type") in ("image", "image_url") and "value" in item:
                value = item["value"]
                if isinstance(value, list):
                    # A single item may carry several images.
                    for img in value:
                        if isinstance(img, Image.Image):
                            images_to_encode[id(img)] = img
                elif isinstance(value, Image.Image):
                    images_to_encode[id(value)] = value
    if not images_to_encode:
        return
    if self.logger:
        self.logger.info(f"Pre-encoding {len(images_to_encode)} images...")
    # Encode images in parallel
    img_items = list(images_to_encode.items())
    encoded_cache = {}

    def encode_task(item):
        # Returns (image id, base64 string) so results can be matched
        # back to their source objects after as_completed reordering.
        img_id, img = item
        return img_id, self.encode_image_directly(img)

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
            futures = {
                executor.submit(encode_task, item): item for item in img_items
            }
            with tqdm(
                total=len(img_items), desc="Encoding images", leave=False
            ) as pbar:
                for future in concurrent.futures.as_completed(futures):
                    # NOTE(review): a non-RuntimeError raised inside
                    # encode_task propagates out of .result() uncaught.
                    img_id, encoded = future.result()
                    encoded_cache[img_id] = encoded
                    pbar.update(1)
    except RuntimeError:
        # Fallback: if thread creation still fails, encode sequentially
        if self.logger:
            self.logger.warning(
                "ThreadPoolExecutor failed (OS thread limit?), "
                "falling back to sequential image encoding."
            )
        for item in tqdm(
            img_items, desc="Encoding images (sequential)", leave=False
        ):
            img_id, encoded = encode_task(item)
            encoded_cache[img_id] = encoded
    finally:
        gc.collect()
    if self.logger:
        self.logger.info(
            "Image byte encoding complete, replacing image references in payloads..."
        )
    # Pass 2: replace PIL images with encoded strings, preserving order.
    # Note: Do NOT close PIL images here - they may be shared references
    # reused by later pipeline stages. They will be garbage collected when
    # no longer referenced.
    for msg in messages:
        if not isinstance(msg, list):
            continue
        for item in msg:
            if not isinstance(item, dict):
                continue
            if item.get("type") in ("image", "image_url") and "value" in item:
                value = item["value"]
                if isinstance(value, list):
                    new_values = []
                    for img in value:
                        if isinstance(img, Image.Image):
                            # Fall back to the original object if no
                            # cache entry was produced for it.
                            encoded = encoded_cache.get(id(img), img)
                            new_values.append(encoded)
                        else:
                            new_values.append(img)
                    item["value"] = new_values
                elif isinstance(value, Image.Image):
                    encoded = encoded_cache.get(id(value), value)
                    item["value"] = encoded
    # Clear references
    images_to_encode.clear()
    encoded_cache.clear()
    if self.logger:
        self.logger.info("Image encoding complete!")
[docs]
def generate(self, messages, **kwargs):
    """Main method to generate responses.

    Args:
        messages: Input messages (list of message dicts). NOTE: the list
            and its nested message lists are cleared in place after
            processing (to release base64 payload memory) — callers must
            not reuse ``messages`` afterwards.
        **kwargs: Additional arguments, merged over ``self.default_kwargs``.

    Returns:
        list: Generated responses for each message, in input order
        (``process_single_message`` yields ``self.fail_msg`` on failure).
    """
    # Deep-copy defaults so per-call kwargs never mutate shared state.
    all_kwargs = cp.deepcopy(self.default_kwargs)
    all_kwargs.update(kwargs)
    # Handle empty messages list
    if not messages:
        return []
    # Pre-encode images (thread-safe)
    self._pre_encode_images(messages)
    # Use dedicated thread pool to avoid resource leaks
    max_workers = min(self.thread_num, len(messages))
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks with retry on thread creation failure
        futures = {}
        for i, msg in enumerate(messages):
            while True:
                try:
                    future = executor.submit(
                        self.process_single_message, msg, all_kwargs
                    )
                    futures[future] = i
                    break
                except RuntimeError as e:
                    if "can't start new thread" in str(e):
                        # Wait for some tasks to complete before submitting more
                        time.sleep(0.1)
                    else:
                        raise
        # Use tqdm to display progress
        results = [None] * len(futures)
        with tqdm(total=len(futures), desc="Processing API Requests") as pbar:
            for future in concurrent.futures.as_completed(futures):
                # Map each completed future back to its input position so
                # results keep the original order.
                index = futures[future]
                results[index] = future.result()
                pbar.update(1)
    # Release base64 data before next batch
    self._clear_message_buffers(messages)
    self._collect_memory(trim=True)
    return results
[docs]
def shutdown(self):
"""Shutdown and release resources."""
pass