"""Request builder for MLLM operators."""
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from datastudio.utils.registry import REQUESTS
if TYPE_CHECKING:
from ..core.data_item import DataItem
def copy_image(image: Any) -> Any:
    """No-op "copy" for request images.

    Images reaching this builder are already encoded as base64 strings,
    which are immutable — sharing the same reference across multiple
    requests is therefore safe and avoids a pointless deep copy.
    """
    return image
@REQUESTS.register_module()
class RequestBuilder:
    """
    MLLM request builder.

    Responsibilities:
        1. Store request_builder configuration
        2. Build requests (build_request, build_requests)
        3. Format content (format_content)
        4. Parse model responses (parse_response)

    Config format::

        request_builder = dict(
            type="RequestBuilder",
            prompt="prompts/filter/xxx.txt",
            system_prompt="prompts/grounding/grounding_system.txt",
            key_templates={"result": "q{idx}", "reason": "q{idx}_reason"},
            with_image=True,
            with_answer=False,
        )

    Attributes:
        prompt: Prompt text or path to .txt file (user message).
        system_prompt: System prompt text or path to .txt file (system message).
        key_templates: Dict mapping field names to key templates with {idx}.
        with_image/with_question/with_answer/with_original: Request flags.
    """

    # Class-level defaults (can be overridden by subclass or instance).
    # NOTE: key_templates is a shared class-level dict; instances replace it
    # wholesale in __init__ and never mutate it in place.
    prompt: str = ""
    system_prompt: str = ""
    key_templates: Dict[str, str] = {"result": "q{idx}", "reason": "q{idx}_reason"}
    with_image: bool = True
    with_question: bool = True
    with_answer: bool = False
    with_original: bool = False

    # Sentinel to distinguish "not provided" from "explicitly None"
    _NOT_PROVIDED = object()

    def __init__(
        self,
        prompt: Optional[str] = None,
        system_prompt: Optional[str] = None,
        key_templates: Union[Dict[str, str], None, object] = _NOT_PROVIDED,
        with_image: Optional[bool] = None,
        with_question: Optional[bool] = None,
        with_answer: Optional[bool] = None,
        with_original: Optional[bool] = None,
        **kwargs,
    ):
        """
        Initialize request builder.

        Instance values override class defaults.
        For key_templates, explicitly passing None disables JSON parsing.
        """
        # prompt: None or "" both mean no prompt
        if prompt:
            self.prompt = prompt
        if system_prompt:
            self.system_prompt = system_prompt
        # key_templates: need to distinguish "not provided" vs "explicitly None"
        if key_templates is not self._NOT_PROVIDED:
            self.key_templates = key_templates
        # Boolean flags: only override if provided (not None)
        if with_image is not None:
            self.with_image = with_image
        if with_question is not None:
            self.with_question = with_question
        if with_answer is not None:
            self.with_answer = with_answer
        if with_original is not None:
            self.with_original = with_original
        # Lazy-loaded prompt text (resolved on first property access)
        self._prompt_text: Optional[str] = None
        self._system_prompt_text: Optional[str] = None

    # === Prompt loading ===

    @property
    def prompt_text(self) -> str:
        """Load and return the prompt text (lazy, cached after first access)."""
        if self._prompt_text is None:
            self._prompt_text = self._load_prompt(self.prompt)
        return self._prompt_text

    @property
    def system_prompt_text(self) -> str:
        """Load and return the system prompt text (lazy, cached after first access)."""
        if self._system_prompt_text is None:
            self._system_prompt_text = self._load_prompt(self.system_prompt)
        return self._system_prompt_text

    def _load_prompt(self, prompt: str) -> str:
        """Load prompt from a .txt file, or return the string as-is.

        An empty/falsy prompt yields "". A path ending in ".txt" must exist,
        otherwise FileNotFoundError is raised (fail fast on bad config).
        """
        if not prompt:
            return ""
        if prompt.endswith(".txt"):
            path = Path(prompt)
            if not path.exists():
                raise FileNotFoundError(f"Prompt file not found: {prompt}")
            return path.read_text(encoding="utf-8")
        return prompt

    # === Build requests ===

    def build_request(
        self,
        item: "DataItem",
        qa_idx: int,
        copy_img: bool = False,
    ) -> Dict:
        """
        Build a single request.

        Args:
            item: DataItem to build request for.
            qa_idx: QA pair index (-1 for batch_qa mode).
            copy_img: Whether to deep copy the image (for multiple requests per item).

        Returns:
            Dict with keys: payload, item, qa_idx
        """
        content = self.format_content(item, qa_idx)
        image = item.image if self.with_image else None
        if copy_img and image is not None:
            image = copy_image(image)
        payload = self._build_payload(content, image)
        return {
            "payload": payload,
            "item": item,
            "qa_idx": qa_idx,
        }

    def build_requests(
        self,
        item: "DataItem",
        batch_qa: bool = False,
    ) -> List[Dict]:
        """
        Build all requests for a DataItem.

        Args:
            item: DataItem to build requests for.
            batch_qa: If True, combine all QA pairs into one request.

        Returns:
            List of request dicts.
        """
        if batch_qa:
            # Single combined request; qa_idx=-1 signals "all QA pairs".
            return [self.build_request(item, qa_idx=-1)]
        requests = []
        for qa_idx in range(item.qa_count):
            # First request uses original image, subsequent ones use copies
            req = self.build_request(item, qa_idx, copy_img=(qa_idx > 0))
            requests.append(req)
        return requests

    def build_requests_selective(
        self,
        item: "DataItem",
        qa_indices: List[int],
    ) -> List[Dict]:
        """
        Build requests for specific QA indices only.

        Used by SelectiveMLLMRewriter for conditional rewriting.

        Args:
            item: DataItem to build requests for.
            qa_indices: List of QA indices to build requests for.

        Returns:
            List of request dicts.
        """
        requests = []
        for i, qa_idx in enumerate(qa_indices):
            # First request uses original image, subsequent ones use copies
            req = self.build_request(item, qa_idx, copy_img=(i > 0))
            requests.append(req)
        return requests

    # === Format content ===

    def format_content(self, item: "DataItem", qa_idx: int) -> str:
        """
        Format request content.

        Args:
            item: DataItem to format.
            qa_idx: QA pair index (-1 for all QA).

        Returns:
            Formatted prompt string.
        """
        # Simple mode: no prompt and single QA, return raw question only
        if not self.prompt_text and qa_idx != -1:
            qa = item.qa_pairs[qa_idx]
            parts = []
            if self.with_question:
                parts.append(qa.question)
            if self.with_answer:
                parts.append(qa.answer)
            if self.with_original and qa.ori_answer:
                parts.append(qa.ori_answer)
            return "\n".join(parts)
        if qa_idx == -1:
            qa_content = item.format_all_qa(
                with_question=self.with_question,
                with_answer=self.with_answer,
                with_original=self.with_original,
            )
        else:
            qa_content = item.format_qa(
                qa_idx,
                with_question=self.with_question,
                with_answer=self.with_answer,
                with_original=self.with_original,
            )
        if self.prompt_text:
            return f"{self.prompt_text}\n{qa_content}"
        return qa_content

    def _build_payload(self, content: str, image: Any = None) -> List[Dict]:
        """Build payload for model.generate().

        Order: optional system message, then text content, then optional image.
        """
        payload = []
        if self.system_prompt_text:
            payload.append({"type": "system", "value": self.system_prompt_text})
        payload.append({"type": "text", "value": content})
        if image is not None:
            payload.append({"type": "image", "value": image})
        return payload

    # === Parse response ===

    def parse_response(
        self,
        response: Any,
        qa_idx: int,
        logger: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Parse model response.

        Args:
            response: Raw response from model.
            qa_idx: QA index (used to format key_templates).
            logger: Optional logger for warnings.

        Returns:
            Dict with parsed fields, e.g. {"result": ..., "reason": ...}
            Returns {"result": None, "parse_error": True} if parsing fails.
        """
        # Non-dict responses (or disabled parsing) are passed through as-is.
        if not isinstance(response, dict) or self.key_templates is None:
            return {"result": response}
        parsed = {}
        missing_keys = []
        for field_name, template in self.key_templates.items():
            key = template.format(idx=qa_idx)
            if key in response:
                parsed[field_name] = response[key]
            else:
                missing_keys.append(key)
        # Warn if expected keys are missing
        if missing_keys and logger:
            logger.warning(
                f"Response missing expected keys: {missing_keys}. "
                f"Response keys: {list(response.keys())}"
            )
        # If no fields were parsed, mark as parse error
        if not parsed and self.key_templates:
            return {"result": None, "parse_error": True}
        return parsed
# Explicit public API for ``from <module> import *``.
__all__ = ["RequestBuilder", "copy_image"]