Development Guide
This document provides detailed instructions for developing custom operators in DataStudio.
Operator Architecture
DataStudio provides two base operator types:
| Base Class | Purpose | Key Method |
|---|---|---|
| Filter | Decide whether to keep or reject QA pairs | check() |
| Rewriter | Transform or modify QA pair content | rewrite() |
All operators are registered via @OPERATORS.register_module() and can be referenced by class name in pipeline configs.
Creating a Filter
Filters implement check() which returns (rejected: bool, reason: str). If rejected=True, the QA pair is filtered out.
# datastudio/operators/filters/my_filter.py
from typing import Any, Optional, Tuple
from datastudio.utils.registry import OPERATORS
from datastudio.operators.core import Filter, DataItem, Result
@OPERATORS.register_module()
class MyCustomFilter(Filter):
    """Reject QA pairs whose answer length falls outside a configured range.

    Args:
        min_length: Shortest answer length that is still kept.
        max_length: Longest answer length that is still kept.
        logger: Optional logger, forwarded to the base ``Filter``.
        **kwargs: Extra keyword arguments, forwarded to the base ``Filter``.
    """

    def __init__(
        self,
        min_length: int = 10,
        max_length: int = 8192,
        logger: Optional[Any] = None,
        **kwargs,
    ):
        super().__init__(logger=logger, **kwargs)
        self.min_length = min_length
        self.max_length = max_length

    def check(self, item: DataItem, qa_idx: int) -> Tuple[bool, str]:
        """Decide whether one QA pair must be rejected.

        Args:
            item: The data item that holds the QA pairs.
            qa_idx: Position of the QA pair to inspect.

        Returns:
            ``(rejected, reason)``; ``reason`` is empty when the pair is kept.
        """
        length = len(item.get_answer(qa_idx))

        # Common case first: the answer length is inside the allowed range.
        if self.min_length <= length <= self.max_length:
            return False, ""

        # Out of range: the "too short" verdict takes precedence.
        if length < self.min_length:
            return True, f"Answer too short: {length} < {self.min_length}"
        return True, f"Answer too long: {length} > {self.max_length}"
Item-level vs QA-level Filtering
QA-level (default): Override check(). The base process() calls check() for each QA pair.
Item-level: Override process() directly for conditions that apply to the entire item (e.g., conversation length, image format).
@OPERATORS.register_module()
class ItemLevelFilter(Filter):
    """Item-level filter example: flags items holding more than 20 QA pairs."""

    def process(self, item: DataItem) -> Result:
        """Evaluate the whole item instead of each QA pair.

        Args:
            item: The data item to evaluate.

        Returns:
            A ``Result`` for this item; it carries a rejection entry when the
            item has more than 20 QA pairs.
        """
        outcome = Result(item_idx=item.idx)
        too_many_pairs = item.qa_count > 20
        if too_many_pairs:
            # NOTE(review): qa_idx=-1 presumably marks the whole item rather
            # than a single QA pair — confirm against Result.add_filter.
            outcome.add_filter(qa_idx=-1, rejected=True, reason="Too many QA pairs")
        return outcome

    def check(self, item: DataItem, qa_idx: int) -> Tuple[bool, str]:
        """Placeholder; this example does its filtering in ``process()``."""
        # Not used
        return False, ""
Creating a Rewriter
Rewriters implement rewrite() to transform the answer text of a QA pair.
# datastudio/operators/rewriters/my_rewriter.py
from typing import Any, Optional
from datastudio.utils.registry import OPERATORS
from datastudio.operators.core import Rewriter, DataItem
@OPERATORS.register_module()
class MyCustomRewriter(Rewriter):
    """Remove leading and trailing whitespace from answers."""

    def __init__(self, logger: Optional[Any] = None, **kwargs):
        super().__init__(logger=logger, **kwargs)

    def rewrite(self, item: DataItem, qa_idx: int) -> Optional[str]:
        """Produce the whitespace-trimmed answer for one QA pair.

        Args:
            item: The data item.
            qa_idx: Index of the QA pair to rewrite.

        Returns:
            The trimmed answer, or ``None`` when trimming changes nothing.
        """
        original = item.get_answer(qa_idx)
        trimmed = original.strip()
        if trimmed == original:
            # Already clean: signal "no change".
            return None
        return trimmed
Registration and Export
1. Register with the decorator
from datastudio.utils.registry import OPERATORS
# The decorator registers the class so pipeline configs can reference it by
# its class name.
@OPERATORS.register_module()
class MyCustomFilter(Filter):
    # Operator implementation elided in this snippet.
    ...
2. Export in __init__.py
# datastudio/operators/filters/__init__.py
from .my_filter import MyCustomFilter
# Add the new operator's class name to the package's public exports.
__all__ = [
    # ... existing exports
    "MyCustomFilter",
]
Using in Config
Once registered, reference operators by class name in config files:
# configs/my_pipeline.py
# NOTE(review): `_base_` presumably pulls in shared base settings — confirm
# against the config loader.
_base_ = ["@/_base_/dataset.py"]
dataset_yaml = "path/to/dataset.yaml"

# Each entry wraps a SubPipeline whose operators are referenced by their
# registered class name; extra keys are passed as operator arguments.
# NOTE(review): `priority` presumably controls execution order — confirm.
sub_pipelines = dict(
    filters=dict(
        cfg=dict(type="SubPipeline", operators=[
            dict(type="MyCustomFilter", min_length=10, max_length=4096),
        ]),
        priority=0,
    ),
    rewriters=dict(
        cfg=dict(type="SubPipeline", operators=[
            dict(type="MyCustomRewriter"),
        ]),
        priority=1,
    ),
)
Writing Tests
Place test files in tests/ with the test_ prefix:
# tests/test_my_filter.py
import pytest
from datastudio.operators.core import DataItem, Result
from datastudio.operators.filters.my_filter import MyCustomFilter
class TestMyCustomFilter:
    """Unit tests for ``MyCustomFilter``."""

    @staticmethod
    def _single_turn_item(answer):
        """Build a one-question, one-answer item with the given answer text."""
        return DataItem(
            {"messages": [
                {"role": "user", "content": "question"},
                {"role": "assistant", "content": answer},
            ]},
            idx=0,
        )

    def test_reject_short_answer(self):
        """Answers shorter than min_length are rejected."""
        length_filter = MyCustomFilter(min_length=10, max_length=100)
        sample = self._single_turn_item("short")

        rejected, reason = length_filter.check(sample, qa_idx=0)

        assert rejected is True
        assert "too short" in reason.lower()

    def test_keep_valid_answer(self):
        """Answers inside the length bounds are kept."""
        length_filter = MyCustomFilter(min_length=1, max_length=1000)
        sample = self._single_turn_item("This is a valid answer.")

        rejected, _ = length_filter.check(sample, qa_idx=0)

        assert rejected is False

    @pytest.mark.parametrize("min_len,max_len", [(1, 5), (10, 100), (50, 500)])
    def test_various_bounds(self, min_len, max_len):
        """The constructor stores whatever bounds it is given."""
        length_filter = MyCustomFilter(min_length=min_len, max_length=max_len)

        assert length_filter.min_length == min_len
        assert length_filter.max_length == max_len
Run tests:
python -m pytest tests/test_my_filter.py -v