Source code for datastudio.operators.rewriters.norm_think

"""Normalize <think> tags rewriter."""

import re
from typing import Any, Optional

from datastudio.utils.registry import OPERATORS

from ..core import DataItem, Rewriter


@OPERATORS.register_module()
class NormThinkRewriter(Rewriter):
    """
    Normalize <think> tag format in answers.

    Ensures consistent format:
    - Lowercase tags
    - Newline after opening tag
    - Newline before closing tag
    - Newline after closing tag when text follows it directly

    Example:
        Input:  "<THINK>reasoning</THINK>answer"
        Output: "<think>\\nreasoning\\n</think>\\nanswer"
    """

    # Ordered (pattern, replacement) pairs. Order matters: the first two
    # entries fold any casing of the tags to lowercase, and the remaining
    # case-sensitive patterns rely on that having already happened.
    THINK_VARIANTS = [
        # Any casing of the opening tag -> "<think>".
        (re.compile(r"<THINK>", re.IGNORECASE), "<think>"),
        # Any casing of the closing tag -> "</think>".
        (re.compile(r"</THINK>", re.IGNORECASE), "</think>"),
        # Opening tag not immediately followed by a newline gets one appended.
        (re.compile(r"<think>(?!\n)"), "<think>\n"),
        # Closing tag not immediately preceded by a newline gets one prepended.
        (re.compile(r"(?<!\n)</think>"), "\n</think>"),
        # Closing tag directly followed by a non-whitespace character gets a
        # newline appended; a following space/tab (or end of text) is left
        # untouched.
        (re.compile(r"</think>(?!\n)(?=\S)"), "</think>\n"),
    ]

    def __init__(
        self,
        logger: Optional[Any] = None,
        **kwargs,
    ):
        """
        Initialize the rewriter.

        Args:
            logger: Logger instance. When truthy, an initialization message
                is emitted through ``logger.info``.
            **kwargs: Forwarded unchanged to the ``Rewriter`` base class.
        """
        super().__init__(logger=logger, **kwargs)
        if logger:
            logger.info(f"[{self.name}] Initialized")

    def rewrite(self, item: DataItem, qa_idx: int) -> Optional[str]:
        """
        Normalize think tags in an answer.

        Args:
            item: Data item holding the QA pair.
            qa_idx: Index of the QA pair passed to ``item.get_qa``.

        Returns:
            The normalized answer when normalization changed it, otherwise
            ``None``. ``None`` is also returned when the answer is not a
            string (e.g. missing), since there is nothing to normalize.
        """
        answer = item.get_qa(qa_idx).answer
        # Regex substitution raises TypeError on non-string input; treat a
        # missing or non-text answer as "no rewrite" instead of crashing.
        if not isinstance(answer, str):
            return None
        normalized = self._normalize(answer)
        return normalized if normalized != answer else None

    def _normalize(self, text: str) -> str:
        """
        Apply all normalization patterns.

        Args:
            text: Answer text to normalize.

        Returns:
            ``text`` after every ``THINK_VARIANTS`` substitution has been
            applied in order.
        """
        result = text
        for pattern, replacement in self.THINK_VARIANTS:
            result = pattern.sub(replacement, result)
        return result