# Source code for datastudio.pipelines.pipeline

"""Main pipeline that composes sub-pipelines with priority-based ordering.

Example config::

    pipeline = dict(
        type='Pipeline',
        operations={
            'basic_filters': dict(
                cfg=dict(type='SubPipeline', operators=[
                    dict(type='ConvLengthFilter', min_length=1),
                    dict(type='RemoveThinkRewriter'),
                ]),
                priority=0,
            ),
            'mllm_filter': dict(
                cfg=dict(type='SubPipeline', operators=[
                    dict(type='MLLMFilter', ...),
                ]),
                priority=1,
            ),
        }
    )
"""

from typing import Any, Dict, List, Optional, Tuple

from datastudio.utils.qa_utils import get_qa_count
from datastudio.utils.registry import MODELS, OPERATORS, PIPELINES, build_from_cfg
from datastudio.utils.statistics import StatsCollector

from .sub_pipeline import SubPipeline


@PIPELINES.register_module()
class Pipeline:
    """Composes sub-pipelines and executes them by priority (lower = earlier).

    Operators are resolved from the unified ``OPERATORS`` registry, so
    sub-pipelines can freely mix filters and rewriters.

    Attributes:
        sub_pipelines: List of (sub_pipeline, name) tuples, sorted by priority.
    """

    def __init__(
        self,
        operations: Dict[str, Dict],
        logger: Any,
        model: Optional[Any] = None,
        stats_collector: Optional[StatsCollector] = None,
    ):
        """
        Initialize the pipeline.

        Args:
            operations: Dict of operation configs with cfg and priority.
            logger: Logger instance.
            model: Optional model instance for MLLM operators.
            stats_collector: Optional statistics collector for tracking
                operator stats.
        """
        self.logger = logger
        self.model = model
        self.stats_collector = stats_collector

        logger.info(f"Building Pipeline with {len(operations)} sub-pipelines...")

        # Sort by priority (lower = earlier); missing priority defaults to 0.
        sorted_ops = sorted(operations.items(), key=lambda x: x[1].get("priority", 0))
        logger.info(f"Execution order: {[op[0] for op in sorted_ops]}")

        # Build sub-pipelines in execution order.
        self.sub_pipelines: List[Tuple[SubPipeline, str]] = []
        for op_name, op in sorted_ops:
            op_cfg = op.get("cfg", {})
            priority = op.get("priority", 0)
            # Build the sub-pipeline using unified OPERATORS registry
            sub_pipeline = self._build_sub_pipeline(
                op_cfg, op_name, priority, logger, model
            )
            self.sub_pipelines.append((sub_pipeline, op_name))
            logger.info(f" [{priority}] {op_name}")

    def _build_sub_pipeline(
        self,
        cfg: Dict,
        name: str,
        priority: int,
        logger: Any,
        model: Optional[Any],
    ) -> SubPipeline:
        """
        Build a sub-pipeline from config.

        Args:
            cfg: Sub-pipeline config dict.
            name: Name of the sub-pipeline.
            priority: Execution priority.
            logger: Logger instance.
            model: Optional model for MLLM operators.

        Returns:
            SubPipeline instance.

        Raises:
            ValueError: If ``cfg`` is not a ``SubPipeline`` config.
        """
        cfg_type = cfg.get("type", "") if isinstance(cfg, dict) else ""
        if cfg_type != "SubPipeline":
            raise ValueError(f"Expected SubPipeline config, got '{cfg_type}'")

        operators = []
        for op_cfg in cfg.get("operators", []):
            # Support per-operator model override: if an operator config contains
            # a "model" dict, build a dedicated model instance for that operator
            # instead of using the global one.
            op_model = model
            if isinstance(op_cfg, dict) and isinstance(op_cfg.get("model"), dict):
                op_model = build_from_cfg(op_cfg["model"], MODELS, logger=logger)
                # Strip the "model" key so the operator builder doesn't see it.
                op_cfg = {k: v for k, v in op_cfg.items() if k != "model"}
            op = build_from_cfg(op_cfg, OPERATORS, logger=logger, model=op_model)
            # build_from_cfg may return None (e.g. disabled operator) — skip it.
            if op is not None:
                operators.append(op)

        return SubPipeline(operators, name=name, priority=priority, logger=logger)

    def __call__(self, datas: List[Dict]) -> List[Dict]:
        """
        Execute all sub-pipelines in priority order.

        Args:
            datas: List of raw data dicts to process.

        Returns:
            List of all processed data dicts (both kept and rejected).
            Rejected items have 'rejected=True' in their dict.
        """
        all_rejected: List[Dict] = []

        for sub_pipeline, name in self.sub_pipelines:
            before_count = len(datas)
            # Calculate QA counts before processing
            before_qa_count = sum(
                get_qa_count(d.get("conversations", [])) for d in datas
            )

            # SubPipeline returns merged list, split for counting
            result = sub_pipeline(datas)
            datas = [d for d in result if not d.get("rejected", False)]
            newly_rejected = [d for d in result if d.get("rejected", False)]
            all_rejected.extend(newly_rejected)

            after_count = len(datas)
            newly_rejected_count = len(newly_rejected)
            # Calculate QA counts after processing
            after_qa_count = sum(
                get_qa_count(d.get("conversations", [])) for d in datas
            )
            qa_rejected_from_input = before_qa_count - after_qa_count

            # Count items rewritten by this sub_pipeline: a kept item counts
            # once if any of this sub-pipeline's operators appears in its
            # "rewrite_ops" record.
            operator_names = [op.name for op in sub_pipeline.operators]
            rewritten_count = sum(
                1
                for d in datas
                if any(
                    op_name in d.get("rewrite_ops", {})
                    for op_name in operator_names
                )
            )

            # Calculate split stats: when partial rejection happens,
            # one input item becomes two output items (kept + rejected)
            total_output_items = after_count + newly_rejected_count
            split_count = (
                total_output_items - before_count
            )  # Number of items that got split

            if self.logger:
                # Log both item-level and QA-level statistics
                log_msg = f"{name}: items {before_count}->{after_count} kept, {newly_rejected_count} rejected"
                if split_count > 0:
                    log_msg += f" ({split_count} partial splits)"
                log_msg += f" | QA {before_qa_count}->{after_qa_count} kept, {qa_rejected_from_input} rejected"
                self.logger.info(log_msg)

            # Record operator statistics
            if self.stats_collector:
                self.stats_collector.record_operator(
                    name=name,
                    input_count=before_count,
                    output_count=after_count,
                    rejected_count=newly_rejected_count,
                    rewritten_count=rewritten_count,
                    input_qa_count=before_qa_count,
                    output_qa_count=after_qa_count,
                    rejected_qa_count=qa_rejected_from_input,
                    split_count=split_count,
                )

        return datas + all_rejected

    def __del__(self):
        # Best-effort completion log; never raise during interpreter teardown.
        try:
            if hasattr(self, "logger") and self.logger:
                self.logger.info("Pipeline completed.")
        except Exception:
            pass  # Ignore errors during cleanup