# Copyright (c) 2024-present, BAAI. All Rights Reserved.
# Licensed under Apache License, Version 2.0
"""Prompt-only dataset for on-policy distillation.
Supports:
1) .txt: one prompt per line (small files)
2) .csv shards (Koala-36M): stream the caption column from many CSV files
This is an IterableDataset, so 10 x 4.89 GB CSV shards are streamed rather than
loaded into memory.
Distributed sharding
--------------------
``__iter__`` automatically detects whether torch.distributed is initialised
and shards files across ranks using file-level slicing:
rank-r reads ``files[r::world_size]``
This keeps each rank's file I/O independent, avoids duplicate prompts within
a global batch, and requires zero coordination overhead.
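Example: with world_size=4 and 10 shards, rank 0 reads shards [0, 4, 8] and
rank 1 reads [1, 5, 9]; no two ranks ever open the same file.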
"""
import csv
import glob
import os
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Union
import torch
from torch.utils.data import IterableDataset
@dataclass
class CSVSpec:
caption_field: str = "caption"
clarity_field: str = "clarity_score"
aesthetic_field: str = "aesthetic_score"
min_clarity: Optional[float] = None
min_aesthetic: Optional[float] = None
def _expand_sources(path_or_glob: str) -> List[str]:
"""Accept:
- file path
- directory (all *.csv inside)
- glob (Koala_36M_*.csv)
- comma-separated list of any of the above
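    e.g. "/data/shards,/extra/Koala_36M_*.csv" (illustrative paths) expands to
    every *.csv under /data/shards plus all glob matches, keeping only paths
    that exist on disk.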
"""
parts = [p.strip() for p in path_or_glob.split(",") if p.strip()]
out: List[str] = []
for p in parts:
if any(ch in p for ch in ["*", "?", "[", "]"]):
out.extend(sorted(glob.glob(p)))
else:
pp = Path(p)
if pp.is_dir():
out.extend(sorted(str(x) for x in pp.glob("*.csv")))
else:
out.append(str(pp))
out = [x for x in out if os.path.exists(x)]
if not out:
raise FileNotFoundError(f"No files found for prompt source: {path_or_glob}")
return out
def _maybe_float(x: Optional[str]) -> Optional[float]:
    """Parse a float, returning None for missing, empty, or non-numeric values."""
    try:
        return float(x)
    except (TypeError, ValueError):
        return None
class PromptDataset(IterableDataset):
"""Stream prompts from txt or csv shards.
Args:
prompt_source: path/dir/glob/comma-list. For Koala, pass something like:
--prompt_file "/data/Koala_36M_*.csv"
shuffle_files: randomize shard order each epoch.
shuffle_buffer: >0 enables approximate shuffle within a sliding buffer.
seed: RNG seed.
infinite: if True, loops over shards forever (recommended for num_steps training).
csv: CSVSpec for caption field and optional score filtering.
encoding: file encoding.
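    Example (illustrative paths and thresholds):
        ds = PromptDataset(
            "/data/Koala_36M_*.csv",
            shuffle_buffer=10_000,
            csv=CSVSpec(min_clarity=0.5, min_aesthetic=4.0),
        )
        it = iter(ds)                       # endless caption stream
        prompts = [next(it) for _ in range(64)]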
"""
def __init__(
self,
prompt_source: str,
shuffle_files: bool = True,
shuffle_buffer: int = 0,
seed: int = 42,
infinite: bool = True,
csv: Optional[CSVSpec] = None,
encoding: str = "utf-8",
):
super().__init__()
self.files = _expand_sources(prompt_source)
self.shuffle_files = shuffle_files
self.shuffle_buffer = int(shuffle_buffer)
self.seed = int(seed)
self.infinite = bool(infinite)
self.csvspec = csv or CSVSpec()
self.encoding = encoding
# Decide mode by extension of the first resolved file
first = self.files[0].lower()
if first.endswith(".txt"):
self.mode = "txt"
elif first.endswith(".csv"):
self.mode = "csv"
else:
raise ValueError(f"Unsupported prompt source type: {self.files[0]} (expect .txt or .csv)")
# -----------------------------------------------------------------------
# Distributed sharding helpers
# -----------------------------------------------------------------------
def _get_dist_files(self) -> List[str]:
"""Return files assigned to this distributed rank (file-level sharding).
With world_size=8 and 80 CSV shards, rank-r reads shards
[r, r+8, r+16, ...]. Falls back to all files when:
- torch.distributed is not initialised, OR
- world_size == 1, OR
- sliced list would be empty (fewer files than ranks).
"""
files = self.files
try:
import torch.distributed as dist
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
rank = dist.get_rank()
world_size = dist.get_world_size()
sliced = files[rank::world_size]
files = sliced if sliced else self.files
except Exception:
pass
return files
def _get_rank(self) -> int:
"""Return current distributed rank (0 when not initialised)."""
try:
import torch.distributed as dist
if dist.is_available() and dist.is_initialized():
return dist.get_rank()
except Exception:
pass
return 0
# -----------------------------------------------------------------------
# Iteration helpers
# -----------------------------------------------------------------------
    def _iter_txt(self, rng: random.Random, files: Optional[List[str]] = None) -> Iterator[str]:
        """Stream lines from .txt files, reshuffling file order each epoch; loops if infinite."""
        files = list(files if files is not None else self.files)
        while True:
            if self.shuffle_files:
                rng.shuffle(files)
            for fp in files:
                with open(fp, "r", encoding=self.encoding) as f:
                    for line in f:
                        s = line.strip()
                        # Skip blank lines and comment lines.
                        if not s or s.startswith("#"):
                            continue
                        yield s
            if not self.infinite:
                break
    def _iter_csv_file(self, fp: str) -> Iterator[str]:
        """Yield the caption column of one Koala CSV, applying optional score filters."""
        cs = self.csvspec
        with open(fp, "r", encoding=self.encoding, newline="") as f:
            reader = csv.DictReader(f)
            # Validate schema once; fieldnames is None for an empty file.
            if not reader.fieldnames or cs.caption_field not in reader.fieldnames:
                raise KeyError(
                    f"CSV missing caption field '{cs.caption_field}'. "
                    f"Got fields: {list(reader.fieldnames or [])[:20]}..."
                )
for row in reader:
cap = (row.get(cs.caption_field) or "").strip()
if not cap:
continue
# Optional filters (if enabled)
if cs.min_clarity is not None:
v = _maybe_float(row.get(cs.clarity_field, ""))
if v is None or v < cs.min_clarity:
continue
if cs.min_aesthetic is not None:
v = _maybe_float(row.get(cs.aesthetic_field, ""))
if v is None or v < cs.min_aesthetic:
continue
yield cap
    def _iter_csv(self, rng: random.Random, files: Optional[List[str]] = None) -> Iterator[str]:
"""Iterate CSV shards; approximate shuffle with buffer if requested."""
files = list(files if files is not None else self.files)
while True:
if self.shuffle_files:
rng.shuffle(files)
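            # Buffered approximate shuffle: keep up to ``shuffle_buffer`` captions
            # in memory and emit a uniformly random element once the buffer is
            # full (the same scheme as tf.data's shuffle buffer).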
if self.shuffle_buffer > 0:
buf: List[str] = []
# Fill and pop randomly
for fp in files:
for cap in self._iter_csv_file(fp):
buf.append(cap)
if len(buf) >= self.shuffle_buffer:
j = rng.randrange(len(buf))
yield buf.pop(j)
# Drain remainder
while buf:
j = rng.randrange(len(buf))
yield buf.pop(j)
else:
for fp in files:
yield from self._iter_csv_file(fp)
if not self.infinite:
break
def __iter__(self) -> Iterator[str]:
# Distributed rank-level file sharding (file-level, assigned once).
dist_files = self._get_dist_files()
rank = self._get_rank()
# Each dataloader worker gets its own RNG (independent per worker AND
# per distributed rank so shuffles don't collide across processes).
wi = torch.utils.data.get_worker_info()
worker_id = 0 if wi is None else wi.id
rng = random.Random(self.seed + 1009 * worker_id + 97 * rank)
if self.mode == "txt":
yield from self._iter_txt(rng, dist_files)
else:
yield from self._iter_csv(rng, dist_files)
def make_collate_fn(tokenizer, max_prompt_length: int, device: torch.device):
"""Tokenize List[str] -> [B, L] tensor.
IMPORTANT: returns a CPU tensor regardless of ``device`` argument.
Move to GPU inside the training step (not in the dataloader worker).
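    Example (assuming a Hugging Face tokenizer with a pad token set):
        collate = make_collate_fn(tok, max_prompt_length=77, device=torch.device("cuda"))
        ids = collate(["a cat", "a dog on a beach"])   # CPU LongTensor of shape [2, 77]
    Note: recent transformers versions accept ``padding_side`` at call time;
    older ones only honor the ``tokenizer.padding_side`` attribute.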
"""
tok_kwargs = {
"max_length": max_prompt_length,
"padding": "max_length",
"padding_side": "left",
"truncation": True,
"return_tensors": "pt",
}
def collate_fn(prompts: List[str]) -> torch.Tensor:
return tokenizer(prompts, **tok_kwargs).input_ids # CPU tensor
return collate_fn
class InfiniteDataLoader:
"""Wraps a DataLoader and cycles indefinitely (works for both map/iter datasets)."""
def __init__(self, dataloader):
self.dataloader = dataloader
self._iter = iter(dataloader)
def __next__(self):
try:
return next(self._iter)
except StopIteration:
self._iter = iter(self.dataloader)
return next(self._iter)
def __iter__(self):
return self
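# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (assumptions: the glob below matches real shards;
# everything else uses only this module). With num_workers > 0 each worker
# re-reads the same rank-local files with an independent RNG, so duplicate
# prompts across workers are possible; keep num_workers=0 unless acceptable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    ds = PromptDataset("/data/Koala_36M_*.csv", shuffle_buffer=1000)
    dl = DataLoader(
        ds,
        batch_size=8,
        num_workers=0,
        collate_fn=lambda prompts: prompts,  # identity collate -> List[str]
    )
    loader = InfiniteDataLoader(dl)
    for step in range(3):
        prompts = next(loader)
        print(step, len(prompts), prompts[0][:60])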