Source code for datastudio.models.mp_openai_api

"""Multiprocessing OpenAI-compatible API wrapper for large-scale workloads.

Wraps :class:`OpenAIAPI` with ``fork``-based multiprocessing so that
CPU-bound image encoding and JSON serialization run in true parallelism
across cores, each worker owning an independent async event loop.

Memory is shared via fork + Copy-on-Write: the main process stores
pre-encoded messages in a module-level global, and forked workers read
their slices without copying (N workers ≈ 1x memory).
"""

import math
import multiprocessing as mp
import os
from typing import Any, List, Optional

from tqdm import tqdm

from datastudio.utils.registry import MODELS

from .base import BaseAPI
from .openai_api import OpenAIAPI

# Module-level globals shared with forked workers via Copy-on-Write.
# Workers read these without modification, so physical pages stay shared.
_SHARED_MESSAGES: Optional[List[Any]] = None
_SHARED_MODEL_KWARGS: Optional[dict] = None


def _init_shared_data(messages: List[Any], model_kwargs: dict):
    """Set module-level globals before forking workers.

    Must be called in the main process before Pool creation.

    Args:
        messages: The full pre-encoded messages list.
        model_kwargs: Dict of OpenAIAPI constructor arguments.
    """
    global _SHARED_MESSAGES, _SHARED_MODEL_KWARGS
    _SHARED_MESSAGES = messages
    _SHARED_MODEL_KWARGS = model_kwargs


def _clear_shared_data():
    """Clear module-level globals after workers finish."""
    global _SHARED_MESSAGES, _SHARED_MODEL_KWARGS
    _SHARED_MESSAGES = None
    _SHARED_MODEL_KWARGS = None


def _worker_fn(args):
    """Worker function executed in each forked subprocess.

    Reads its message slice from the inherited ``_SHARED_MESSAGES`` global
    (zero-copy via COW) and runs an independent :class:`OpenAIAPI` instance.

    Args:
        args: Tuple of (worker_id, start_idx, end_idx, worker_concurrency).

    Returns:
        List of response strings for the slice ``[start_idx, end_idx)``.
    """
    worker_id, start_idx, end_idx, worker_concurrency = args

    if start_idx >= end_idx:
        return []

    # Read slice from inherited global — no copy, COW shared pages
    messages_slice = _SHARED_MESSAGES[start_idx:end_idx]

    # Build per-worker OpenAIAPI from inherited kwargs
    model_kwargs = dict(_SHARED_MODEL_KWARGS)
    model_kwargs["thread_num"] = worker_concurrency
    # Suppress per-worker logging to avoid noisy output
    model_kwargs.pop("logger", None)

    api = OpenAIAPI(**model_kwargs)
    results = api.generate(messages_slice)
    api.shutdown()
    return results
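

# Example of the tuples generate() feeds to the pool (hypothetical numbers):
# with 10_000 messages, num_workers=8, and worker_concurrency=1024 (the default
# thread_num of 8192 split across 8 workers), worker 3 receives
# (3, 3750, 5000, 1024) and reads _SHARED_MESSAGES[3750:5000].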


@MODELS.register_module()
class MPOpenAIAPI(BaseAPI):
    """Multiprocessing OpenAI-compatible API using fork + COW shared memory.

    Distributes requests across multiple worker processes, each running an
    independent :class:`OpenAIAPI` with its own async event loop. Designed for
    workloads where single-process async saturates CPU on image encoding /
    JSON serialization, the total request count is large, and payload size
    makes pickle-based IPC infeasible.

    Memory model:
        Pre-encoded messages are stored in a module-level global before
        ``fork()``. Workers inherit the parent's address space via
        Copy-on-Write and only read shared data, so N workers ≈ 1x memory.
    """
    def __init__(
        self,
        model: str,
        key: str = None,
        temperature: float = 0.6,
        top_p: float = 0.95,
        top_k: int = 20,
        min_p: float = 0.0,
        presence_penalty: float = 1.0,
        repetition_penalty: float = 1.0,
        api_base: str = None,
        port: int = None,
        retry: int = 10,
        wait: int = 3,
        timeout: tuple = (30, 1800),
        max_tokens: int = 16384,
        thread_num: int = 8192,
        return_dict: bool = False,
        logger=None,
        max_connections: int = None,
        num_workers: int = None,
        worker_concurrency: int = None,
        enable_thinking: bool = False,
        **kwargs,
    ):
        """Initialize the multiprocessing OpenAI API wrapper.

        Args:
            model: Model name, e.g., gpt-4, Qwen2-VL-72B.
            key: API key (uses the OPENAI_API_KEY env var if not provided).
            temperature: Generation temperature.
            top_p: Nucleus sampling threshold (0~1).
            top_k: Top-K sampling (0 to disable). vLLM/SGLang extra param.
            min_p: Minimum probability threshold. vLLM/SGLang extra param.
            presence_penalty: Penalize tokens already present in the output.
            repetition_penalty: Penalize repeated tokens. vLLM/SGLang extra param.
            api_base: API base URL (full URL to the chat/completions endpoint).
            port: Port number for local deployments.
            retry: Number of retry attempts on API failure.
            wait: Max wait time between retries (seconds).
            timeout: Request timeout as a (connect_timeout, read_timeout) tuple.
            max_tokens: Maximum tokens in the response.
            thread_num: Total maximum concurrent requests across all workers.
            return_dict: Whether to parse the response as a dict.
            logger: Logger instance.
            max_connections: Max TCP connections per worker. Defaults to
                min(worker_concurrency, 16384).
            num_workers: Number of worker processes. Defaults to
                min(cpu_count, 8). Set based on available CPU cores.
            worker_concurrency: Async concurrency limit per worker. Defaults to
                thread_num // num_workers.
            enable_thinking: Whether to enable thinking mode for models that
                support it (e.g., Qwen3 on vLLM/SGLang). Default False.
            **kwargs: Additional API parameters passed to the payload.
        """
        super().__init__(
            retry=retry,
            wait=wait,
            timeout=timeout,
            logger=logger,
            openai_format=True,
            image_key_name="image_url",
            thread_num=thread_num,
            **kwargs,
        )
        self.model = model
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
        self.presence_penalty = presence_penalty
        self.repetition_penalty = repetition_penalty
        self.max_tokens = max_tokens
        self.return_dict = return_dict
        self.enable_thinking = enable_thinking

        # Resolve API key early so workers don't need env vars
        self.key = key or os.environ.get("OPENAI_API_KEY", "")

        # Store API URL construction params
        self.api_base = api_base
        self.port = port

        # Build resolved API URL (same logic as OpenAIAPI)
        if not api_base:
            self._resolved_api_base = os.environ.get("OPENAI_API_BASE", "")
        else:
            if port is not None:
                self._resolved_api_base = f"{api_base}:{port}/v1/chat/completions"
            else:
                self._resolved_api_base = api_base

        # Multiprocessing configuration
        cpu_count = os.cpu_count() or 4
        self.num_workers = num_workers or min(cpu_count, 8)

        # Per-worker concurrency: split total concurrency across workers
        if worker_concurrency is not None:
            self.worker_concurrency = worker_concurrency
        else:
            self.worker_concurrency = max(1, thread_num // self.num_workers)

        # Per-worker max connections
        if max_connections is not None:
            self.worker_max_connections = max_connections
        else:
            self.worker_max_connections = min(self.worker_concurrency, 16384)

        if logger:
            logger.info(
                f"MPOpenAIAPI initialized: "
                f"workers={self.num_workers}, "
                f"concurrency_per_worker={self.worker_concurrency}, "
                f"total_concurrency={self.num_workers * self.worker_concurrency}"
            )
            logger.info(f"Using API URL: {self._resolved_api_base}")
    def _build_model_kwargs(self) -> dict:
        """Build the kwargs dict used to reconstruct OpenAIAPI in worker processes.

        Returns:
            Dict of constructor arguments for OpenAIAPI.
        """
        kwargs = {
            "model": self.model,
            "key": self.key,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "min_p": self.min_p,
            "presence_penalty": self.presence_penalty,
            "repetition_penalty": self.repetition_penalty,
            "api_base": self._resolved_api_base,
            "retry": self.retry,
            "wait": self.wait,
            "timeout": self.timeout,
            "max_tokens": self.max_tokens,
            "thread_num": self.worker_concurrency,
            "return_dict": self.return_dict,
            "max_connections": self.worker_max_connections,
            "enable_thinking": self.enable_thinking,
        }
        kwargs.update(self.default_kwargs)
        return kwargs
    def generate(self, messages: List[Any], **kwargs) -> List[str]:
        """Generate responses using fork-based multiprocessing.

        Args:
            messages: List of input messages.
            **kwargs: Additional generation arguments.

        Returns:
            List of response strings, in the same order as the input.
        """
        if not messages:
            return []

        total = len(messages)

        # Pre-encode PIL images to base64 so they are fork-safe and COW-shared
        if self.logger:
            self.logger.info(f"Pre-encoding images for {total} messages...")
        self._pre_encode_images(messages)

        # Store in module-level globals for COW sharing after fork
        model_kwargs = self._build_model_kwargs()
        model_kwargs.update(kwargs)
        _init_shared_data(messages, model_kwargs)

        # Compute index ranges for each worker
        num_workers = min(self.num_workers, total)
        chunk_size = math.ceil(total / num_workers)
        worker_args = []
        for worker_id in range(num_workers):
            start = worker_id * chunk_size
            end = min(start + chunk_size, total)
            worker_args.append((worker_id, start, end, self.worker_concurrency))

        if self.logger:
            self.logger.info(
                f"Forking {num_workers} workers for {total} messages "
                f"(~{chunk_size}/worker, concurrency={self.worker_concurrency}/worker)"
            )

        # Fork workers (safe: no threads/event loops exist yet in main process)
        ctx = mp.get_context("fork")
        all_results = []
        try:
            with ctx.Pool(processes=num_workers) as pool:
                with tqdm(total=total, desc="MPOpenAIAPI Progress") as pbar:
                    for chunk_results in pool.imap(_worker_fn, worker_args):
                        all_results.extend(chunk_results)
                        pbar.update(len(chunk_results))
        finally:
            # Always clear shared data to release memory
            _clear_shared_data()

        if self.logger:
            fail_count = sum(1 for r in all_results if r == self.fail_msg)
            self.logger.info(
                f"Completed: {total} requests, "
                f"{total - fail_count} succeeded, {fail_count} failed"
            )

        return all_results
    def generate_inner(self, inputs, **kwargs):
        """Sync version for a single request (BaseAPI compatibility).

        Delegates to a temporary OpenAIAPI instance.

        Args:
            inputs: Preprocessed message list (OpenAI format).
            **kwargs: Additional parameters.

        Returns:
            Tuple of (ret_code, answer, response).
        """
        model_kwargs = self._build_model_kwargs()
        api = OpenAIAPI(**model_kwargs)
        return api.generate_inner(inputs, **kwargs)
    def shutdown(self):
        """Shutdown and release resources.

        Workers are ephemeral (created per generate() call), so there is
        nothing to clean up.
        """
        pass
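

# A minimal usage sketch (not part of the library): the model name, endpoint,
# and prompt payloads below are placeholders; pass whatever message structure
# the wrapped OpenAIAPI.generate accepts in your deployment.
if __name__ == "__main__":
    api = MPOpenAIAPI(
        model="Qwen2-VL-72B",         # placeholder model name
        api_base="http://localhost",  # placeholder endpoint; port is appended
        port=8000,
        num_workers=8,                # match available CPU cores
        thread_num=4096,              # total concurrency across all workers
    )
    prompts = ["Describe this image.", "Summarize the document."]
    responses = api.generate(prompts)
    for prompt, response in zip(prompts, responses):
        print(f"{prompt} -> {response}")
    api.shutdown()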