# Source code for datastudio.operators.mllm.base

"""Base MLLM operator for model-powered filtering and rewriting."""

import gc
import json
from abc import abstractmethod
from typing import Any, Dict, List, Optional

from datastudio.utils.registry import REQUESTS, build_from_cfg

from ..core.data_item import DataItem
from ..core.operator import Operator
from ..core.result import Result


class MLLMOperator(Operator):
    """Base class for MLLM-powered operators.

    Builds requests via :class:`RequestBuilder`, calls ``model.generate()``,
    and delegates response parsing to subclasses via :meth:`_add_decision`.
    """

    def __init__(
        self,
        model: Any,
        request_builder: Optional[Dict] = None,
        batch_qa: bool = False,
        logger: Optional[Any] = None,
        **kwargs,
    ):
        """Initialize the MLLM operator.

        Args:
            model: MLLM model with generate() method.
            request_builder: RequestBuilder config dict.
            batch_qa: If True, combine all QA pairs into one request; the
                model is expected to answer with {q0: ..., q1: ..., ...}.
                If False (default), each QA pair becomes its own request.
            logger: Logger instance.
        """
        super().__init__(logger=logger, **kwargs)
        if model is None:
            raise ValueError("model is required")
        self.model = model
        self.batch_qa = batch_qa

        # Instantiate the RequestBuilder from its registry config.
        self.request_builder = build_from_cfg(request_builder, REQUESTS)
        if self.request_builder is None:
            raise ValueError("request_builder is required and must be a valid config")

    # === Main processing ===

    def process(self, item: DataItem) -> Result:
        """Process a single item (delegates to process_batch)."""
        return self.process_batch([item])[0]

    def process_batch(self, items: List[DataItem]) -> List[Result]:
        """Process a batch of items.

        Flow:
        1. Build requests using RequestBuilder
        2. Execute model
        3. Parse responses and aggregate into Results
        """
        # One (initially empty) Result per item, keyed by item index.
        results: Dict[int, Result] = {
            item.idx: Result(item_idx=item.idx) for item in items
        }

        # Step 1: build every request up front; nothing to do if none.
        pending = self._build_requests(items)
        if not pending:
            return [results[item.idx] for item in items]

        # Step 2: run the model over the raw payloads.
        self._log_example_request(pending)
        payloads = [req["payload"] for req in pending]
        responses = self.model.generate(payloads)
        self._log_example_response(responses)
        self._cleanup_requests(pending, payloads)

        # Step 3: parse each response and attach decisions to its Result.
        # Also track which items the model actually saw.
        processed_item_idxs = set()
        for req, resp in zip(pending, responses):
            item: DataItem = req["item"]
            qa_idx: int = req["qa_idx"]
            processed_item_idxs.add(item.idx)

            if qa_idx == -1:
                # Batch mode: one response covers every QA pair of the item.
                for i in range(item.qa_count):
                    parsed = self.request_builder.parse_response(resp, i, self.logger)
                    self._add_decision(results[item.idx], i, parsed)
            else:
                # Single mode: the response carries q0, mapped onto qa_idx.
                parsed = self.request_builder.parse_response(resp, 0, self.logger)
                self._add_decision(results[item.idx], qa_idx, parsed)

        # Step 4: record which model processed each touched item.
        model_name = getattr(self.model, "model", None)
        if model_name:
            for item in items:
                if item.idx in processed_item_idxs:
                    item.add_model_record(self.name, model_name)

        return [results[item.idx] for item in items]

    # === Build requests ===

    def _build_requests(self, items: List[DataItem]) -> List[Dict]:
        """Build all requests for all items."""
        collected: List[Dict] = []
        for item in items:
            if self._should_skip(item):
                continue
            collected.extend(self.request_builder.build_requests(item, self.batch_qa))
        return collected

    def _should_skip(self, item: DataItem) -> bool:
        """Check if this item should be skipped (image required but absent)."""
        return self.request_builder.with_image and not item.has_image

    @staticmethod
    def _cleanup_requests(all_requests: List[Dict], payloads: List[Any]) -> None:
        """Drop payload references (may hold large images) and force a GC pass."""
        del payloads
        for req in all_requests:
            req.pop("payload", None)
        gc.collect()

    # === Abstract method for subclasses ===

    @abstractmethod
    def _add_decision(self, result: Result, qa_idx: int, parsed: Dict):
        """Add decision based on parsed response.

        Subclasses implement this to add FilterDecision or RewriteDecision.

        Args:
            result: Result to add decision to.
            qa_idx: QA pair index.
            parsed: Parsed response dict with "result", "reason", etc.
                May contain "parse_error": True if parsing failed.
        """
        pass

    # === Logging ===

    def _log_example_request(self, requests: List[Dict]):
        """Log a sample request for debugging."""
        if not self.logger or not requests:
            return

        # The middle request serves as a representative sample.
        sample = requests[len(requests) // 2]
        payload = sample["payload"]

        # Swap image values for short placeholders before JSON-dumping.
        loggable = []
        for part in payload:
            if part.get("type") == "image" and part.get("value") is not None:
                value = part["value"]
                if isinstance(value, list):
                    placeholder = f"<List of {len(value)} PIL.Image>"
                else:
                    placeholder = "<PIL.Image>"
                loggable.append({"type": "image", "value": placeholder})
            else:
                loggable.append(part.copy())

        self.logger.info(
            f"\n[{self.name}] Example request:\n"
            f"{json.dumps(loggable, indent=2, ensure_ascii=False)}\n"
        )

    def _log_example_response(self, responses: List[Any]):
        """Log a sample response for debugging."""
        if not self.logger or not responses:
            return

        response = responses[len(responses) // 2]
        try:
            if isinstance(response, str):
                response = json.loads(response)
            if isinstance(response, dict):
                self.logger.info(
                    f"\n[{self.name}] Example response:\n"
                    f"{json.dumps(response, indent=2, ensure_ascii=False)}\n"
                )
            else:
                self.logger.info(f"\n[{self.name}] Example response:\n{response}\n")
        except (json.JSONDecodeError, TypeError):
            self.logger.warning(
                f"\n[{self.name}] Example response (not JSON):\n{response}\n"
            )