"""Async OpenAI-compatible API wrapper.
Uses aiohttp with semaphore-based concurrency control and connection pooling
to handle high-throughput batch inference against OpenAI-compatible endpoints.
"""
import asyncio
import os
import time
from typing import Any, List, Optional, Tuple
import aiohttp
import requests
from tqdm.asyncio import tqdm as async_tqdm
from datastudio.utils.registry import MODELS
from .base import BaseAPI
@MODELS.register_module()
class OpenAIAPI(BaseAPI):
"""Async OpenAI-compatible API wrapper.
Uses a shared :class:`aiohttp.ClientSession` with connection pooling
and ``asyncio.Semaphore`` for concurrency control. Each request handles
retries independently with exponential backoff.
"""
def __init__(
    self,
    model: str,
    key: Optional[str] = None,
    temperature: float = 0.6,
    top_p: float = 0.95,
    top_k: int = 20,
    min_p: float = 0.0,
    presence_penalty: float = 1.0,
    repetition_penalty: float = 1.0,
    api_base: Optional[str] = None,
    port: Optional[int] = None,
    retry: int = 10,
    wait: int = 3,
    timeout: tuple = (30, 1800),
    max_tokens: int = 16384,
    thread_num: int = 8192,
    return_dict: bool = False,
    logger=None,
    max_connections: Optional[int] = None,
    enable_thinking: bool = False,
    **kwargs,
):
    """Initialize OpenAI-compatible API wrapper.

    Args:
        model: Model name, e.g., gpt-4, Qwen2-VL-72B.
        key: API key (uses OPENAI_API_KEY env var if not provided).
        temperature: Generation temperature.
        top_p: Nucleus sampling threshold (0~1).
        top_k: Top-K sampling (0 to disable). vLLM/SGLang extra param.
        min_p: Minimum probability threshold. vLLM/SGLang extra param.
        presence_penalty: Penalize tokens already present in the output.
        repetition_penalty: Penalize repeated tokens. vLLM/SGLang extra param.
        api_base: API base URL (full URL to chat/completions endpoint).
        port: Port number for local deployments.
        retry: Number of retry attempts on API failure.
        wait: Max wait time between retries (seconds).
        timeout: Request timeout as (connect_timeout, read_timeout) tuple.
        max_tokens: Maximum tokens in response.
        thread_num: Maximum number of concurrent requests (semaphore limit).
        return_dict: Whether to parse response as dict via load_str_to_dict.
        logger: Logger instance.
        max_connections: Max TCP connections in the pool. Defaults to
            min(thread_num, 16384).
        enable_thinking: Whether to enable thinking mode for models that
            support it (e.g., Qwen3 on vLLM/SGLang). Default False.
        **kwargs: Additional API parameters passed to payload.
    """
    super().__init__(
        retry=retry,
        wait=wait,
        timeout=timeout,
        logger=logger,
        openai_format=True,
        image_key_name="image_url",
        thread_num=thread_num,
        **kwargs,
    )
    self.model = model
    self.temperature = temperature
    self.top_p = top_p
    self.top_k = top_k
    self.min_p = min_p
    self.presence_penalty = presence_penalty
    self.repetition_penalty = repetition_penalty
    self.max_tokens = max_tokens
    self.return_dict = return_dict
    # Semaphore limit for in-flight requests.
    self.concurrency_limit = thread_num
    self.enable_thinking = enable_thinking
    # Connection pool size: never larger than the concurrency limit cap.
    if max_connections is None:
        self.max_connections = min(thread_num, 16384)
    else:
        self.max_connections = max_connections
    # Get API key (explicit argument wins over environment).
    self.key = key or os.environ.get("OPENAI_API_KEY", "")
    # Build API URL.
    if not api_base:
        self.api_base = os.environ.get("OPENAI_API_BASE", "")
        # Surface misconfiguration early instead of failing on first request.
        if not self.api_base and logger:
            logger.warning(
                "No api_base given and OPENAI_API_BASE is unset; "
                "requests will fail until an API URL is configured"
            )
    elif port is not None:
        self.api_base = f"{api_base}:{port}/v1/chat/completions"
    else:
        if logger:
            logger.warning("Port not specified, using api_base directly")
        self.api_base = api_base
    if logger:
        logger.info(f"Using API URL: {self.api_base}")
        logger.info(
            f"Concurrency: {self.concurrency_limit}, "
            f"Connections: {self.max_connections}"
        )
def _build_openai_headers(self) -> dict:
    """Construct request headers from the instance's stored API key.

    Thin override of ``BaseAPI._build_openai_headers`` that supplies
    ``self.key`` so callers never have to pass the key per request.

    Returns:
        dict: ``Content-Type`` plus an optional ``Authorization`` header.
    """
    return super()._build_openai_headers(self.key)
# Sampling parameter names managed explicitly. _build_payload (and the
# sync generate_inner path) reads each of these with a per-call override,
# then filters them out of **kwargs so every key appears in the request
# payload exactly once.
_SAMPLING_PARAM_KEYS = (
    "temperature",
    "max_tokens",
    "top_p",
    "top_k",
    "min_p",
    "presence_penalty",
    "repetition_penalty",
)
def _build_payload(self, message: List[dict], **kwargs) -> dict:
    """Assemble the chat-completions request body.

    Args:
        message: Preprocessed message list (OpenAI format).
        **kwargs: Per-call overrides for sampling parameters plus any
            extra payload fields forwarded verbatim.

    Returns:
        dict: JSON-serializable request payload.
    """
    # Each sampling parameter falls back to the instance default when the
    # caller does not override it.
    sampling = {
        name: kwargs.get(name, getattr(self, name))
        for name in self._SAMPLING_PARAM_KEYS
    }
    # Forward everything else untouched; sampling keys are excluded so
    # they cannot show up twice in the payload.
    passthrough = {
        k: v for k, v in kwargs.items() if k not in self._SAMPLING_PARAM_KEYS
    }
    payload = {
        "model": self.model,
        "messages": message,
        **sampling,
        **passthrough,
    }
    # Request structured JSON output when return_dict is enabled.
    if self.return_dict:
        payload["response_format"] = {"type": "json_object"}
    # Thinking-mode switch for vLLM/SGLang chat templates (e.g. Qwen3);
    # any caller-supplied chat_template_kwargs take precedence.
    template_cfg = payload.setdefault("chat_template_kwargs", {})
    template_cfg.setdefault("enable_thinking", self.enable_thinking)
    return payload
async def _request_non_streaming(
    self,
    session: aiohttp.ClientSession,
    message: List[dict],
    **kwargs,
) -> Tuple[int, Any]:
    """Make a single non-streaming chat-completions request.

    Args:
        session: Shared aiohttp session.
        message: Preprocessed message list.
        **kwargs: Additional payload parameters.

    Returns:
        Tuple of (ret_code, response_content): ``(0, answer)`` on
        success, ``(1, error_string)`` on a non-200 HTTP status.
    """
    headers = self._build_openai_headers()
    payload = self._build_payload(message, **kwargs)
    async with session.post(
        self.api_base, json=payload, headers=headers
    ) as response:
        # Guard clause: any non-200 status is reported as a retryable error.
        if response.status != 200:
            error_text = await response.text()
            return 1, f"HTTP {response.status}: {error_text[:200]}"
        # content_type=None: accept servers that mislabel the JSON body.
        data = await response.json(content_type=None)
    # Parse OUTSIDE the connection scope so the socket is freed promptly.
    return 0, self._parse_openai_response(data)
async def _make_request(
    self,
    session: aiohttp.ClientSession,
    message: List[dict],
    **kwargs,
) -> Tuple[int, Any]:
    """Dispatch one API request.

    Currently always delegates to the non-streaming handler; exists as a
    seam so subclasses can route to other request modes.

    Returns:
        Tuple of (ret_code, response_content).
    """
    return await self._request_non_streaming(session, message, **kwargs)
def _calculate_retry_delay(self, attempt: int) -> float:
    """Compute the back-off delay before the next retry attempt.

    Defers to the base-class exponential backoff (full jitter) while
    capping the delay at 2.0 seconds.

    Args:
        attempt: Zero-based retry attempt number.

    Returns:
        Seconds to sleep before retrying.
    """
    return super()._calculate_retry_delay(attempt, max_delay=2.0)
async def _process_single_request(
    self,
    session: aiohttp.ClientSession,
    semaphore: asyncio.Semaphore,
    processed_msg: Any,
    all_kwargs: dict,
    pbar: Optional[async_tqdm] = None,
    error_counter: Optional[dict] = None,
) -> str:
    """Process a single request with concurrency control and retry.

    Fixes over the previous version:
      * ``pbar.set_postfix`` was called unguarded in several branches and
        raised ``AttributeError`` when ``pbar`` was None (its default)
        while ``error_counter`` was provided;
      * the ``active`` counter was decremented twice on the
        invalid-response path (once on the success path, again in the
        generic exception handler), driving it negative;
      * the generic exception handler left ``response_struct`` stale, so
        the final failure log showed an unrelated earlier error.

    Args:
        session: Shared aiohttp session.
        semaphore: Concurrency limiter.
        processed_msg: Preprocessed message (OpenAI format).
        all_kwargs: Additional generation arguments.
        pbar: Optional progress bar.
        error_counter: Optional error tracking dict.

    Returns:
        Response content string, or ``self.fail_msg`` after exhausting
        all retry attempts.
    """

    def _note(**deltas) -> None:
        # Apply counter deltas and refresh the progress-bar postfix.
        # Both error_counter and pbar are optional and guarded here.
        if error_counter is None:
            return
        for key, delta in deltas.items():
            error_counter[key] += delta
        if pbar is not None:
            pbar.set_postfix(**error_counter)

    ret_code = None
    response_struct = None
    connect_timeout, read_timeout = self.timeout
    # Per-attempt ceiling: read timeout plus 60s slack for setup overhead.
    per_attempt_timeout = read_timeout + 60
    async with semaphore:
        for i in range(self.retry):
            t0 = asyncio.get_event_loop().time()
            try:
                _note(active=1)
                try:
                    ret_code, response_struct = await asyncio.wait_for(
                        self._make_request(session, processed_msg, **all_kwargs),
                        timeout=per_attempt_timeout,
                    )
                finally:
                    # Decrement exactly once no matter how the await ends.
                    _note(active=-1)
                if ret_code == 0 and response_struct and response_struct != "":
                    _note(ok=1)
                    if pbar:
                        pbar.update(1)
                    return response_struct
                raise Exception(f"Invalid response: {response_struct}")
            except aiohttp.ServerTimeoutError:
                elapsed = asyncio.get_event_loop().time() - t0
                ret_code = 1
                # Heuristic: a quick failure is a connect timeout, a slow
                # one is a read timeout.
                if elapsed < connect_timeout + 5:
                    timeout_type = (
                        f"connect timeout (sock_connect={connect_timeout}s)"
                    )
                else:
                    timeout_type = (
                        f"read timeout (sock_read={per_attempt_timeout - 60}s)"
                    )
                response_struct = f"aiohttp {timeout_type}"
                _note(timeout=1)
                # Only warn on the last third of attempts to limit noise.
                if self.logger and i + 1 > self.retry * 2 // 3:
                    self.logger.warning(
                        f"Attempt {i+1}/{self.retry} {timeout_type} "
                        f"(elapsed {elapsed:.1f}s). "
                        f"Server may be overloaded or unreachable."
                    )
            except asyncio.TimeoutError:
                elapsed = asyncio.get_event_loop().time() - t0
                ret_code = 1
                response_struct = "Request timeout"
                _note(timeout=1)
                if self.logger and i + 1 > self.retry * 2 // 3:
                    self.logger.warning(
                        f"Attempt {i+1}/{self.retry} timed out "
                        f"after {elapsed:.1f}s (limit {per_attempt_timeout}s)"
                    )
            except aiohttp.ClientError as e:
                ret_code = 1
                response_struct = f"ClientError: {type(e).__name__}"
                _note(error=1)
            except Exception as e:
                ret_code = 1
                # Record the actual error so the final failure log is useful.
                response_struct = f"{type(e).__name__}: {str(e)[:200]}"
                _note(error=1)
                if i + 1 > self.retry * 2 // 3 and self.logger:
                    self.logger.warning(
                        f"Attempt {i+1}/{self.retry} failed, "
                        f"RetCode: {ret_code}, "
                        f"Error: {type(e).__name__}: {str(e)[:200]}"
                    )
            if i < self.retry - 1:
                await asyncio.sleep(self._calculate_retry_delay(i))
    # All attempts exhausted.
    _note(fail=1)
    if self.logger:
        self.logger.warning(
            f"Failed after {self.retry} attempts, "
            f"RetCode: {ret_code}, Response: {response_struct}"
        )
    if pbar:
        pbar.update(1)
    return self.fail_msg
async def _generate_async(
    self,
    processed_messages: List[Any],
    all_kwargs: dict,
) -> List[str]:
    """Fan out all requests over one pooled aiohttp session.

    A single :class:`aiohttp.ClientSession` is shared by every request;
    an ``asyncio.Semaphore`` bounds in-flight concurrency, and a warm-up
    task gradually widens the semaphore so that thousands of TCP
    connections are not opened in the same instant.

    Args:
        processed_messages: Pre-processed messages (OpenAI format).
        all_kwargs: Additional generation arguments applied to every request.

    Returns:
        List of response strings, aligned with ``processed_messages``.
    """
    # Per-socket timeouts only; total=None because requests can be long.
    sock_connect, sock_read = self.timeout
    client_timeout = aiohttp.ClientTimeout(
        total=None,
        sock_connect=sock_connect,
        sock_read=sock_read,
    )
    # Connection pool never exceeds the concurrency limit.
    pool_size = min(self.max_connections, self.concurrency_limit)
    connector = aiohttp.TCPConnector(
        limit=pool_size,
        ttl_dns_cache=300,
        force_close=False,
        keepalive_timeout=300,
    )
    total = len(processed_messages)
    if self.logger:
        self.logger.info(
            f"[OpenAI API] Processing {total} requests with concurrency={self.concurrency_limit}, "
            f"connections={pool_size}"
        )
    error_counter = {"ok": 0, "active": 0, "timeout": 0, "error": 0, "fail": 0}
    pbar = async_tqdm(total=total, desc="Processing API Requests", leave=True)
    pbar.set_postfix(ok=0, active=0, timeout=0, error=0, fail=0)
    # Warm-up: begin with a small semaphore and let a background task
    # release the remaining slots, doubling the batch each second.
    warmup_start = min(64, self.concurrency_limit)
    ramping = self.concurrency_limit > warmup_start and total > warmup_start
    semaphore = asyncio.Semaphore(
        warmup_start if ramping else self.concurrency_limit
    )
    async with aiohttp.ClientSession(
        timeout=client_timeout, connector=connector
    ) as session:
        ramp_task = None
        if ramping:
            pending_slots = self.concurrency_limit - warmup_start

            async def _widen_semaphore():
                done = 0
                step = warmup_start
                while done < pending_slots:
                    await asyncio.sleep(1.0)
                    grant = min(step, pending_slots - done)
                    for _ in range(grant):
                        semaphore.release()
                    done += grant
                    step = (
                        min(step * 2, pending_slots - done)
                        if done < pending_slots
                        else 0
                    )

            ramp_task = asyncio.create_task(_widen_semaphore())
        coros = [
            self._process_single_request(
                session, semaphore, msg, all_kwargs, pbar, error_counter
            )
            for msg in processed_messages
        ]
        results = await asyncio.gather(*coros, return_exceptions=True)
        if ramp_task is not None:
            ramp_task.cancel()
            try:
                await ramp_task
            except asyncio.CancelledError:
                pass
    pbar.refresh()
    pbar.close()
    # Map stray exceptions and empty results onto the failure sentinel.
    final_results = []
    for item in results:
        if isinstance(item, Exception):
            if self.logger:
                self.logger.warning(f"Task exception: {item}")
            final_results.append(self.fail_msg)
        elif item is None:
            final_results.append(self.fail_msg)
        else:
            final_results.append(item)
    return final_results
def generate(self, messages: List[Any], **kwargs) -> List[str]:
    """Generate responses for a batch of messages.

    Pre-encodes images, preprocesses each message into OpenAI format,
    then executes the async pipeline via ``asyncio.run``. When called
    from inside an already-running event loop (e.g. a Jupyter notebook),
    the pipeline runs on a dedicated thread with its own loop instead.

    Args:
        messages: List of input messages.
        **kwargs: Additional generation arguments.

    Returns:
        List of response strings, in the same order as input.
    """
    if not messages:
        return []
    # Phase 1: image pre-encoding.
    t_encode = time.time()
    self._pre_encode_images(messages)
    encode_elapsed = time.time() - t_encode
    # Phase 2: request preprocessing.
    t_preproc = time.time()
    processed_messages = [self.pre_process(msg) for msg in messages]
    preproc_elapsed = time.time() - t_preproc
    # Phase 3: drop redundant original buffers.
    t_cleanup = time.time()
    self._clear_message_buffers(messages)
    cleanup_elapsed = time.time() - t_cleanup
    if self.logger:
        self.logger.info(
            "[OpenAIAPI] prepare phases: "
            f"pre_encode={encode_elapsed:.2f}s, "
            f"preprocess={preproc_elapsed:.2f}s, "
            f"clear_message_buffers={cleanup_elapsed:.2f}s"
        )
    all_kwargs = dict(self.default_kwargs)
    all_kwargs.update(kwargs)
    # Detect whether we are already inside a running event loop.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    if loop is not None and loop.is_running():
        # asyncio.run() would raise inside a running loop, so execute the
        # pipeline on a fresh thread that owns its own loop.
        import threading

        # Fix: logger may be None (previously crashed here unguarded).
        if self.logger:
            self.logger.warning("Running in existing event loop, using thread")
        result = [None]

        def _run():
            result[0] = asyncio.run(
                self._generate_async(processed_messages, all_kwargs)
            )

        thread = threading.Thread(target=_run)
        thread.start()
        thread.join()
        final = result[0]
    else:
        final = asyncio.run(self._generate_async(processed_messages, all_kwargs))
    # Release base64 payloads before the next batch.
    del processed_messages
    self._collect_memory(trim=True)
    return final
def generate_inner(self, inputs, **kwargs):
    """Sync version of generate for a single request (BaseAPI compatibility).

    Delegates payload construction to :meth:`_build_payload` so the
    sampling-parameter handling, ``response_format`` and
    ``chat_template_kwargs`` logic stay in one place (previously this
    method carried a verbatim copy of that code, which risked drift).

    Args:
        inputs: Preprocessed message list (OpenAI format).
        **kwargs: Additional parameters (sampling overrides and extras).

    Returns:
        Tuple of (ret_code, answer, response): ``(0, answer, response)``
        on success, ``(1, exception, response)`` when the response body
        cannot be parsed.
    """
    headers = self._build_openai_headers()
    # Single source of truth for payload assembly (same as async path).
    payload = self._build_payload(inputs, **kwargs)
    # requests accepts the (connect, read) timeout tuple directly.
    response = requests.post(
        self.api_base,
        headers=headers,
        json=payload,
        timeout=self.timeout,
    )
    try:
        answer = self._parse_openai_response(response.json())
    except Exception as e:
        return 1, e, response
    return 0, answer, response
def shutdown(self):
    """Release held resources (intentionally a no-op).

    Every call to :meth:`generate` opens and closes its own aiohttp
    session, so there is nothing persistent to tear down here.
    """
    return None