# Source code for datastudio.operators.mllm.rewriter

"""MLLM-based rewriter operator."""

from typing import Any, Callable, Dict, List, Optional

from datastudio.utils.registry import OPERATORS

from ..core.data_item import DataItem
from ..core.result import Result
from .base import MLLMOperator


@OPERATORS.register_module()
class MLLMRewriter(MLLMOperator):
    """Rewrite data using MLLM.

    Supports both structured (JSON dict) and plain text responses. Set
    ``key_templates=None`` in the request builder for plain text mode.

    Example config::

        request_builder = dict(
            type="RequestBuilder",
            prompt="prompts/rewriter/xxx.txt",
            key_templates={"result": "q{idx}_answer"},
            with_image=True,
            with_answer=True,
        )
    """

    # Accepted values for ``rewrite_type``.
    _VALID_REWRITE_TYPES = ("answer", "question")

    def __init__(
        self,
        model: Any,
        request_builder: Optional[Dict] = None,
        rewrite_type: str = "answer",
        batch_qa: bool = False,
        logger: Optional[Any] = None,
        **kwargs,
    ):
        """
        Initialize the MLLM rewriter.

        Args:
            model: MLLM model instance.
            request_builder: RequestBuilder config dict.
            rewrite_type: What to rewrite - "answer" or "question".
            batch_qa: If True, combine all QA pairs into one request.
            logger: Logger instance.
            **kwargs: Forwarded to ``MLLMOperator.__init__``.

        Raises:
            ValueError: If ``rewrite_type`` is not "answer" or "question".
        """
        # Reject unknown values up front: _add_decision treats every
        # non-"answer" value as "question", so a typo such as "answers" would
        # otherwise silently overwrite questions with the model output.
        if rewrite_type not in self._VALID_REWRITE_TYPES:
            raise ValueError(
                f"rewrite_type must be one of {self._VALID_REWRITE_TYPES}, "
                f"got {rewrite_type!r}"
            )
        super().__init__(
            model=model,
            request_builder=request_builder,
            batch_qa=batch_qa,
            logger=logger,
            **kwargs,
        )
        self.rewrite_type = rewrite_type
        if logger:
            logger.info(f"[{self.name}] Initialized (rewrite_type={rewrite_type})")

    def _add_decision(self, result: Result, qa_idx: int, parsed: Dict):
        """
        Add a rewrite decision for one QA pair based on the parsed response.

        Nothing is recorded when the response failed to parse, carries no
        ``"result"`` value, or the value is an empty/whitespace-only string.

        Args:
            result: Result to add decision to.
            qa_idx: QA pair index.
            parsed: Parsed response. The rewritten text is read from its
                ``"result"`` key; a truthy ``"parse_error"`` entry marks a
                response that could not be parsed.
        """
        # Handle parse error - skip rewriting rather than store garbage.
        if parsed.get("parse_error"):
            if self.logger:
                self.logger.warning(
                    f"[{self.name}] Parse error for qa_idx={qa_idx}, skipping rewrite"
                )
            return

        content = parsed.get("result")
        if content is None:
            return

        # A blank rewrite would wipe the existing question/answer text;
        # treat it as "no rewrite" instead.
        if isinstance(content, str) and not content.strip():
            if self.logger:
                self.logger.warning(
                    f"[{self.name}] Empty rewrite for qa_idx={qa_idx}, skipping rewrite"
                )
            return

        # rewrite_type is validated in __init__, so the else branch is "question".
        if self.rewrite_type == "answer":
            result.add_rewrite(qa_idx, new_answer=content, message="rewritten answer")
        else:
            result.add_rewrite(
                qa_idx, new_question=content, message="rewritten question"
            )
@OPERATORS.register_module()
class SelectiveMLLMRewriter(MLLMRewriter):
    """MLLM rewriter restricted to QA pairs chosen by a predicate.

    A QA pair is included in the MLLM requests only when
    ``should_rewrite_fn(question, answer)`` is truthy. Without a predicate,
    every QA pair of every non-skipped item is included.

    Example::

        rewriter = SelectiveMLLMRewriter(
            model=model,
            request_builder=dict(type="RequestBuilder", prompt="prompts/translate.txt"),
            should_rewrite_fn=is_mixed_language,
        )
    """

    def __init__(
        self,
        should_rewrite_fn: Optional[Callable[[str, str], bool]] = None,
        **kwargs,
    ):
        """
        Set up the selective rewriter.

        Args:
            should_rewrite_fn: Predicate ``(question, answer) -> bool``; a QA
                pair is rewritten only when it returns True. ``None`` selects
                every QA pair.
            **kwargs: Forwarded unchanged to ``MLLMRewriter.__init__``.
        """
        super().__init__(**kwargs)
        self.should_rewrite_fn = should_rewrite_fn

    def _build_requests(self, items: List[DataItem]) -> List[Dict]:
        """Build requests that cover only the QA pairs selected for rewriting."""
        collected: List[Dict] = []
        for data_item in items:
            if self._should_skip(data_item):
                continue

            selected_indices = []
            for pair in data_item.qa_pairs:
                if self._should_rewrite(pair.question, pair.answer):
                    selected_indices.append(pair.idx)

            # Items with no selected QA pair produce no request at all.
            if not selected_indices:
                continue

            # Go through RequestBuilder's selective method so image handling
            # stays correct for the reduced set of QA pairs.
            collected += self.request_builder.build_requests_selective(
                data_item, selected_indices
            )
        return collected

    def _should_rewrite(self, question: str, answer: str) -> bool:
        """Return the predicate's verdict for a QA pair (True when no predicate is set)."""
        predicate = self.should_rewrite_fn
        return True if predicate is None else predicate(question, answer)