# Source code for datastudio.operators.rewriters.split

"""Split multi-turn conversations rewriter."""

import copy
from typing import Any, Dict, List, Optional

from datastudio.utils.registry import OPERATORS

from ..core import DataItem, Operator, Result


@OPERATORS.register_module()
class SplitRewriter(Operator):
    """
    Split multi-turn conversations into single-turn conversations.

    Each QA pair becomes a separate data item, preserving all metadata and
    adjusting indices accordingly.

    Note:
        This is a special Operator that returns multiple items, handled
        specially by Pipeline.
    """

    def __init__(
        self,
        logger: Optional[Any] = None,
        **kwargs,
    ):
        """
        Initialize the rewriter.

        Args:
            logger: Logger instance.
            **kwargs: Forwarded unchanged to ``Operator.__init__``.
        """
        super().__init__(logger=logger, **kwargs)
        if logger:
            logger.info(f"[{self.name}] Initialized")

    def process(self, item: DataItem) -> Result:
        """
        Split is handled specially - we mark items for expansion.

        The actual splitting is done in expand_items().
        """
        return Result(item_idx=item.idx)

    def process_batch(self, items: List[DataItem]) -> List[Result]:
        """Process batch - returns keep result for each item."""
        return [self.process(item) for item in items]

    def expand_items(self, items: List[DataItem]) -> List[DataItem]:
        """
        Expand multi-turn items into single-turn items.

        This should be called directly instead of through Pipeline for
        splitting operations.

        Every returned item has an ``idx`` equal to its position in the
        returned list, so indices are unique and contiguous.

        Args:
            items: Items to expand. Items with ``qa_count <= 1`` pass through.

        Returns:
            A new list holding, in input order, each single-turn input and one
            item per QA pair of each multi-turn input.
        """
        expanded: List[DataItem] = []
        for item in items:
            if item.qa_count <= 1:
                # Keep the original object when its index already matches its
                # output position; otherwise re-create it so it cannot collide
                # with the renumbered split items.
                if item.idx == len(expanded):
                    expanded.append(item)
                else:
                    expanded.append(DataItem(item.data, idx=len(expanded)))
                continue

            # Split each QA pair into a separate item.
            for qa_idx in range(item.qa_count):
                new_data = self._extract_single_turn(item.data, qa_idx, item.qa_count)
                expanded.append(DataItem(new_data, idx=len(expanded)))
        return expanded

    def _extract_single_turn(
        self,
        data: Dict,
        qa_idx: int,
        total_qa_count: int,
    ) -> Dict:
        """
        Extract a single turn from multi-turn data.

        Args:
            data: Source item data; not modified.
            qa_idx: Zero-based index of the QA pair to keep.
            total_qa_count: Number of QA pairs in the source item; recorded
                in the split note.

        Returns:
            A deep copy of ``data`` whose ``conversations`` holds only the
            selected pair (entries ``2*qa_idx`` and ``2*qa_idx + 1``) and
            whose per-turn metadata is re-indexed to 0.
        """
        new_data = copy.deepcopy(data)

        # Slice the deep copy (not ``data``) so the returned turns do not
        # alias the source item's conversation entries.
        start = qa_idx * 2
        new_data["conversations"] = new_data["conversations"][start : start + 2]

        # Re-index metadata fields
        self._reindex_metadata(new_data, qa_idx)

        # Add split record; ``rewrite_ops`` may be absent or explicitly null.
        if new_data.get("rewrite_ops") is None:
            new_data["rewrite_ops"] = {}
        new_data["rewrite_ops"][self.name] = {
            0: f"split from {total_qa_count}-turn conversation (was qa_{qa_idx})"
        }
        return new_data

    def _reindex_metadata(self, data: Dict, target_idx: int):
        """
        Re-index metadata for an extracted single turn, in place.

        Only the entry for ``target_idx`` is kept, renumbered to index 0.

        Args:
            data: Item data to modify in place.
            target_idx: Index of the QA pair being kept.
        """
        # Per-turn metadata dicts are re-keyed to the string key "0".
        for field in ("ori_answer", "ori_question", "score"):
            original = data.get(field)
            if not isinstance(original, dict):
                continue
            value = self._get_indexed(original, str(target_idx), target_idx)
            data[field] = {"0": value} if value is not None else {}

        # Nested ops records are re-keyed to the int key 0.
        for ops_field in ("filter_ops", "rewrite_ops"):
            ops = data.get(ops_field)
            if not isinstance(ops, dict):
                continue
            new_ops = {}
            for op_name, records in ops.items():
                if isinstance(records, dict):
                    value = self._get_indexed(records, target_idx, str(target_idx))
                    if value is not None:
                        new_ops[op_name] = {0: value}
                else:
                    # Non-dict record (e.g., full-item filter reason): keep as-is.
                    new_ops[op_name] = records
            data[ops_field] = new_ops

    @staticmethod
    def _get_indexed(records: Dict, *keys: Any) -> Any:
        """
        Return the first non-None value stored under any of ``keys``, else None.

        Falsy values such as ``0`` or ``""`` count as present and are returned.
        """
        for key in keys:
            value = records.get(key)
            if value is not None:
                return value
        return None