# Source code for datastudio.datasets.data_loader

"""Dataset loader with LMDB image caching and checkpoint resume support."""

import io
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Iterator, List, Union

from PIL import Image
from tqdm import tqdm

from datastudio.datasets.formatters import FormatRegistry
from datastudio.utils.database import ShardedLMDBManager, write_images_to_sharded_lmdb
from datastudio.utils.qa_utils import calculate_qa_statistics
from datastudio.utils.registry import DATALOADER
from datastudio.utils.vision import apply_resize_image_pil, load_and_resize_image


@DATALOADER.register_module()
class StandardDataLoader:
    """Dataset loader with LMDB image caching and checkpoint resume.

    Supports automatic format detection, parallel image loading, adaptive
    batch sizing, and item-level checkpoint for resumable processing.

    Example::

        loader = StandardDataLoader(
            data_root='/data/datasets',
            dataset={'file_path': 'train.jsonl'},
            batch_size=32,
            logger=logger,
        )
        for batch in loader:
            process(batch)
    """

    def __init__(
        self,
        data_root,
        dataset,
        batch_size,
        parallel_loading=True,
        num_workers=256,
        logger=None,
        cache_dir="~/cache/images_lmdb_sharded",
        lmdb_num_shards=32,
        lmdb_map_size_per_shard=1024 * 1024 * 1024 * 1024,  # 1TB per shard
        lmdb_readonly=False,
        lmdb_lock=False,
        resize_image=True,
        resize_image_size=1024,
        use_image=True,
        use_lmdb_cache=True,  # Whether to cache images in LMDB; False = read from disk directly
        adjust_batch_size=True,  # Whether to adjust batch_size based on turns
        checkpoint_manager=None,  # Checkpoint manager for item-level resume
        **kwargs,  # Accept but ignore deprecated parameters like pre_load_img
    ):
        """Load the dataset file, apply any checkpoint, and compute batch sizing.

        Args:
            data_root: Root directory that relative dataset/image paths are
                resolved against.
            dataset: Dataset file path string, or a config dict containing a
                file path (see ``_resolve_data_path``).
            batch_size: Requested batch size; may be adjusted by
                ``calculate_qa_statistics`` when ``adjust_batch_size`` is True.
            parallel_loading: Use ``num_workers`` threads for image loading;
                otherwise a single worker is used.
            num_workers: Thread count for parallel image loading/caching.
            logger: Optional logger for progress/info messages.
            cache_dir: Directory for the sharded LMDB image cache
                (``~`` is expanded).
            lmdb_num_shards: Number of LMDB shards.
            lmdb_map_size_per_shard: LMDB map size per shard, in bytes.
            lmdb_readonly: Open LMDB read-only.
            lmdb_lock: Enable LMDB file locking.
            resize_image: Resize images on load (max edge
                ``resize_image_size``, adaptive for multi-image items).
            resize_image_size: Max edge length used when resizing.
            use_image: If False, batches are yielded without loading images.
            use_lmdb_cache: Cache images in LMDB; False reads from disk directly.
            adjust_batch_size: Adjust batch_size based on QA turns per item.
            checkpoint_manager: Optional manager providing item-level resume.
            **kwargs: Ignored; absorbs deprecated parameters (e.g. pre_load_img).
        """
        self.data_root = data_root
        self.batch_size = batch_size
        self.original_config = None
        self.parallel_loading = parallel_loading
        self.num_workers = num_workers
        self.logger = logger
        self.resize_image = resize_image
        self.resize_image_size = resize_image_size
        self.use_image = use_image
        self.use_lmdb_cache = use_lmdb_cache
        self.adjust_batch_size = adjust_batch_size
        self.checkpoint_manager = checkpoint_manager
        self.name = ""

        # Store LMDB config for lazy initialization
        self._cache_dir = os.path.expanduser(cache_dir) if cache_dir else None
        self._lmdb_num_shards = lmdb_num_shards
        self._lmdb_map_size_per_shard = lmdb_map_size_per_shard
        self._lmdb_readonly = lmdb_readonly
        self._lmdb_lock = lmdb_lock
        self.sharded_lmdb = None  # Created lazily by _ensure_lmdb_initialized()

        # Resolve dataset path and load data (supports JSON/JSONL formats)
        data_file = self._resolve_data_path(dataset, data_root)
        self.data_file = data_file
        self.all_datas = FormatRegistry.load(
            data_file, add_source_file=True, remove_rejected=True
        )
        self.total_items = len(self.all_datas)

        # Apply checkpoint: skip already processed items
        self.start_idx = 0
        if self.checkpoint_manager:
            self.start_idx = self.checkpoint_manager.get_start_index(
                data_file, self.total_items
            )
            if self.start_idx == -1:
                # File already completed
                self.datas = []
                self.start_idx = self.total_items
            else:
                self.datas = self.all_datas[self.start_idx :]
        else:
            self.datas = self.all_datas

        # Calculate QA statistics and adjust batch_size
        stats = calculate_qa_statistics(
            self.datas,
            batch_size=batch_size,
            adjust_batch_size=adjust_batch_size,
            logger=logger,
        )
        self.avg_qa_count = stats["avg_qa_count"]
        self.max_qa_count = stats["max_qa_count"]
        self.total_qa_pairs = stats["total_qa_pairs"]
        self.real_batch_size = stats["real_batch_size"]

        if self.logger:
            self.logger.info(
                f"Dataset size: {len(self.datas)}/{self.total_items} (start_idx={self.start_idx}), "
                f"avg turns: {self.avg_qa_count:.2f}, max turns: {self.max_qa_count}, "
                f"adjusted batch_size: {self.real_batch_size}, "
                f"resize: {self.resize_image}, image_max_size: {self.resize_image_size}"
            )

        self.idx = 0
        # Track current absolute index for checkpoint updates
        self.current_abs_idx = self.start_idx
def _ensure_lmdb_initialized(self) -> None:
    """Ensure LMDB manager is initialized (lazy initialization)."""
    if self.sharded_lmdb is None:
        self.sharded_lmdb = ShardedLMDBManager.get_instance(
            cache_dir=self._cache_dir,
            num_shards=self._lmdb_num_shards,
            map_size_per_shard=self._lmdb_map_size_per_shard,
            readonly=self._lmdb_readonly,
            lock=self._lmdb_lock,
            logger=self.logger,
        )

def _resolve_data_path(self, dataset: Union[str, Dict], data_root: str) -> str:
    """Resolve dataset file path (supports JSON/JSONL formats).

    Also sets ``self.name`` to the resolved path as the dataset identifier.

    Args:
        dataset: Dataset path string or config dict.
        data_root: Root directory for relative paths.

    Returns:
        Absolute path to the data file.

    Raises:
        ValueError: If a config dict is given without any file-path field.
    """
    if isinstance(dataset, dict):
        # Support multiple field names: file_path, json_path, jsonl_path, etc.
        from datastudio.datasets.config import DatasetConfig

        config = DatasetConfig.from_dict(dataset)
        if not config.file_path:
            if self.logger:
                self.logger.error(
                    f"Missing file path in dataset configuration: {dataset}"
                )
            raise ValueError("Missing file path (file_path/json_path/jsonl_path)")
        data_file = config.file_path
        if not os.path.isabs(data_file):
            data_file = os.path.join(data_root, data_file)
    else:
        data_file = os.path.join(data_root, dataset)
    self.name = data_file
    return data_file

def _get_adaptive_resize_size(self, image_count: int) -> int:
    """Get adaptive resize size based on image count.

    More images in one item means each image is resized smaller to bound
    the total pixel budget.

    Args:
        image_count: Number of images in the item.

    Returns:
        Resize size (max edge length) based on image count:
        - 1-2 images: use original resize_image_size
        - 3-5 images: 1024
        - 6-10 images: 672
        - 11-15 images: 448
        - 15+ images: 392
    """
    if image_count > 15:
        return 392
    elif image_count > 10:
        return 448
    elif image_count > 5:
        return 672
    elif image_count > 2:
        return 1024
    else:
        return self.resize_image_size

def _load_images_for_batch(self, batch: List[Dict]) -> List[Dict]:
    """Load images for a batch from LMDB cache or disk.

    When ``use_lmdb_cache=True``, images are cached to LMDB first and then
    read back. When ``use_lmdb_cache=False``, images are read directly
    from disk.

    Loaded PIL images are attached to each item under ``"image_pil"``
    (a single image, or a list parallel to ``item["image"]``). Items share
    the same PIL object when paths repeat (read-only downstream).

    Args:
        batch: List of data items; each may carry an ``"image"`` key that
            is a path string or a list of path strings.

    Returns:
        The same batch list, mutated in place with ``"image_pil"`` entries.
    """
    if not self.use_image:
        return batch

    if not self.use_lmdb_cache:
        # Direct mode: multi-threaded loading from disk.
        # Each task is (item_index, abs_path, resize_size, is_list[, list_pos]).
        tasks = []
        for i, item in enumerate(batch):
            img_path = item.get("image")
            if img_path and isinstance(img_path, str):
                actual_path = (
                    img_path
                    if os.path.isabs(img_path)
                    else os.path.join(self.data_root, img_path)
                )
                resize_size = self.resize_image_size if self.resize_image else 0
                tasks.append((i, actual_path, resize_size, False))
            elif img_path and isinstance(img_path, list):
                # Multi-image items use an adaptive (smaller) resize size.
                adaptive_size = (
                    self._get_adaptive_resize_size(len(img_path))
                    if self.resize_image
                    else 0
                )
                for j, img in enumerate(img_path):
                    actual_img = (
                        img
                        if os.path.isabs(img)
                        else os.path.join(self.data_root, img)
                    )
                    tasks.append((i, actual_img, adaptive_size, True, j))
        if tasks:
            num_workers = min(
                self.num_workers if self.parallel_loading else 1, len(tasks)
            )
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                futures = {}
                for task in tasks:
                    future = executor.submit(
                        load_and_resize_image, task[1], task[2]
                    )
                    futures[future] = task
                for future in futures:
                    task = futures[future]
                    img_pil = future.result()
                    item_idx = task[0]
                    is_list = task[3]
                    if not is_list:
                        batch[item_idx]["image_pil"] = img_pil
                    else:
                        # Initialize list if needed (positions filled by list_pos).
                        if "image_pil" not in batch[item_idx]:
                            batch[item_idx]["image_pil"] = [None] * len(
                                batch[item_idx]["image"]
                            )
                        batch[item_idx]["image_pil"][task[4]] = img_pil
        return batch

    self._ensure_lmdb_initialized()

    # Step 1: Collect unique image paths for this batch
    image_paths_set = set()
    for item in batch:
        img_path = item.get("image")
        if img_path:
            if isinstance(img_path, list):
                image_paths_set.update(img_path)
            else:
                image_paths_set.add(img_path)
    image_paths = list(image_paths_set)

    if image_paths:
        # Step 2: Write batch images to LMDB (skips already-cached ones internally)
        write_images_to_sharded_lmdb(
            self.sharded_lmdb,
            image_paths,
            data_root=self.data_root,
            batch_size=2000,
            resize_image=self.resize_image,
            resize_image_size=self.resize_image_size,
            num_workers=self.num_workers if self.parallel_loading else 1,
        )

    # Step 3: Batch read unique images from LMDB
    if image_paths:
        img_bytes_map = self.sharded_lmdb.batch_get(image_paths)
    else:
        img_bytes_map = {}

    # Step 4: Decode each unique image ONCE (avoid decoding same image N times)
    def _decode_one(path_and_bytes):
        """Decode cached image bytes to a PIL Image; None on failure."""
        path, raw_bytes = path_and_bytes
        try:
            img_pil = Image.open(io.BytesIO(raw_bytes))
            img_pil.load()
            return path, img_pil
        except Exception:
            return path, None

    decode_items = [(p, b) for p, b in img_bytes_map.items() if b is not None]
    img_pil_map = {}
    if decode_items:
        num_workers = min(
            self.num_workers if self.parallel_loading else 1,
            len(decode_items),
            32,
        )
        show_decode_progress = len(decode_items) >= 1000 and self.logger is not None
        decode_pbar = None
        if show_decode_progress:
            decode_pbar = tqdm(
                total=len(decode_items),
                desc="Image decode",
                unit="img",
            )
        # try/finally ensures the progress bar is closed even if decoding
        # (or result iteration) raises.
        try:
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                for path, img_pil in executor.map(_decode_one, decode_items):
                    if img_pil is not None:
                        img_pil_map[path] = img_pil
                    if decode_pbar is not None:
                        decode_pbar.update(1)
        finally:
            if decode_pbar is not None:
                decode_pbar.close()

    # Step 5: Pre-resize images and cache by (path, size) to avoid repeated resize
    # Items share the same PIL Image object (read-only in downstream pipeline)
    _resized_cache = {}

    def _get_resized(path, resize_size):
        """Get resized image, caching by (path, size) to avoid redundant work."""
        cache_key = (path, resize_size)
        if cache_key not in _resized_cache:
            img_pil = img_pil_map.get(path)
            if img_pil is None:
                _resized_cache[cache_key] = None
            elif self.resize_image:
                _resized_cache[cache_key] = apply_resize_image_pil(
                    img_pil, resize_size
                )
            else:
                _resized_cache[cache_key] = img_pil
        return _resized_cache[cache_key]

    # Step 6: Distribute to batch items (shared references, no copy)
    for item in batch:
        img_path = item.get("image")
        if img_path and isinstance(img_path, str):
            item["image_pil"] = _get_resized(img_path, self.resize_image_size)
        elif img_path and isinstance(img_path, list):
            image_count = len(img_path)
            adaptive_size = self._get_adaptive_resize_size(image_count)
            item["image_pil"] = [
                _get_resized(img, adaptive_size) for img in img_path
            ]
    return batch
def get_name(self) -> str:
    """Return the dataset name (the resolved data-file path)."""
    dataset_name = self.name
    return dataset_name
def __len__(self) -> int:
    """Return the number of batches (ceiling of items / batch size)."""
    item_count = len(self.datas)
    # Ceiling division via negation trick: -(-a // b) == ceil(a / b) for b > 0.
    return -(-item_count // self.real_batch_size)

def __iter__(self) -> Iterator[List[Dict]]:
    """Reset iteration state and return self."""
    self.idx = 0
    self.current_abs_idx = self.start_idx
    return self

def __next__(self) -> List[Dict]:
    """Yield the next batch, loading its images on demand.

    Raises:
        StopIteration: When all remaining items have been yielded.
    """
    total = len(self.datas)
    if self.idx >= total:
        raise StopIteration
    # Clamp the final batch to whatever items remain.
    end = min(self.idx + self.real_batch_size, total)
    batch_data = self.datas[self.idx:end]
    self.idx = end
    # Keep the absolute position in sync for checkpoint updates.
    self.current_abs_idx = self.start_idx + self.idx
    # Load images for this batch on-demand.
    return self._load_images_for_batch(batch_data)
def update_checkpoint(self) -> None:
    """Persist current progress through the checkpoint manager, if configured."""
    manager = self.checkpoint_manager
    if not manager:
        return
    manager.update(self.name, self.current_abs_idx, self.total_items)
def is_completed(self) -> bool:
    """Report whether the checkpoint manager marks this dataset as fully processed.

    Returns False when no checkpoint manager is configured.
    """
    manager = self.checkpoint_manager
    return manager.is_completed(self.name) if manager else False
def has_checkpoint(self) -> bool:
    """Report whether checkpoint tracking is enabled and progress exists.

    Returns False when no checkpoint manager is configured or no progress
    has been recorded for this dataset.
    """
    manager = self.checkpoint_manager
    if not manager:
        return False
    return manager.get_progress(self.name) is not None
def is_empty(self) -> bool:
    """Report whether there are no items left to process."""
    return not self.datas
def __call__(self):
    """Return the next batch, wrapping around to the start when exhausted.

    Unlike ``__next__``, this never raises StopIteration: an empty dataset
    yields ``[]``, and reaching the end restarts from index 0.
    """
    if not self.datas:
        return []
    total = len(self.datas)
    if self.idx >= total:
        # Wrap around instead of signalling exhaustion.
        self.idx = 0
    # Clamp the final batch to whatever items remain.
    end = min(self.idx + self.real_batch_size, total)
    batch_data = self.datas[self.idx:end]
    self.idx = end
    # Keep the absolute position in sync for checkpoint updates.
    self.current_abs_idx = self.start_idx + self.idx
    # Load images for this batch on-demand.
    return self._load_images_for_batch(batch_data)