"""Prompt-only dataset for on-policy distillation.

Supports:

1) .txt: one prompt per line (small files)
2) .csv shards (Koala-36M): stream caption column from many CSV files

This is an IterableDataset to avoid loading 10x 4.89GB CSV shards into memory.

Distributed sharding
--------------------
``__iter__`` automatically detects whether torch.distributed is initialised
and shards files across ranks using file-level slicing:

    rank-r reads ``files[r::world_size]``

This keeps each rank's file I/O independent, avoids duplicate prompts within
a global batch, and requires zero coordination overhead.
"""

import csv
import glob
import os
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator, List, Optional

import torch
from torch.utils.data import IterableDataset


@dataclass
class CSVSpec:
    """Column names and optional quality filters for CSV prompt shards."""

    caption_field: str = "caption"
    clarity_field: str = "clarity_score"
    aesthetic_field: str = "aesthetic_score"
    min_clarity: Optional[float] = None
    min_aesthetic: Optional[float] = None


def _expand_sources(path_or_glob: str) -> List[str]:
    """Accept:
    - file path
    - directory (all *.csv inside)
    - glob (Koala_36M_*.csv)
    - comma-separated list of any of the above
    """
    parts = [p.strip() for p in path_or_glob.split(",") if p.strip()]
    out: List[str] = []
    for p in parts:
        if any(ch in p for ch in ["*", "?", "[", "]"]):
            out.extend(sorted(glob.glob(p)))
        else:
            pp = Path(p)
            if pp.is_dir():
                out.extend(sorted(str(x) for x in pp.glob("*.csv")))
            else:
                out.append(str(pp))
    out = [x for x in out if os.path.exists(x)]
    if not out:
        raise FileNotFoundError(f"No files found for prompt source: {path_or_glob}")
    return out
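

# Example (illustrative): each call resolves to a list of existing files, or
# raises FileNotFoundError. The paths below are placeholders.
#
#   _expand_sources("prompts.txt")
#   _expand_sources("/data/koala_csvs")         # every *.csv inside the directory
#   _expand_sources("/data/Koala_36M_*.csv")    # glob pattern
#   _expand_sources("a.txt,/data/extra/*.csv")  # comma-separated mix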


def _maybe_float(x: str) -> Optional[float]:
    """Parse a float, returning None on any failure (missing, blank, or garbled value)."""
    try:
        return float(x)
    except Exception:
        return None


class PromptDataset(IterableDataset):
    """Stream prompts from .txt files or .csv shards.

    Args:
        prompt_source: path/dir/glob/comma-list. For Koala, pass something like:
            --prompt_file "/data/Koala_36M_*.csv"
        shuffle_files: randomize shard order each epoch.
        shuffle_buffer: >0 enables approximate shuffling within a sliding buffer
            (CSV mode only).
        seed: RNG seed.
        infinite: if True, loops over shards forever (recommended for num_steps training).
        csv: CSVSpec for the caption field and optional score filtering.
        encoding: file encoding.
    """

    def __init__(
        self,
        prompt_source: str,
        shuffle_files: bool = True,
        shuffle_buffer: int = 0,
        seed: int = 42,
        infinite: bool = True,
        csv: Optional[CSVSpec] = None,
        encoding: str = "utf-8",
    ):
        super().__init__()
        self.files = _expand_sources(prompt_source)
        self.shuffle_files = shuffle_files
        self.shuffle_buffer = int(shuffle_buffer)
        self.seed = int(seed)
        self.infinite = bool(infinite)
        self.csvspec = csv or CSVSpec()
        self.encoding = encoding

        first = self.files[0].lower()
        if first.endswith(".txt"):
            self.mode = "txt"
        elif first.endswith(".csv"):
            self.mode = "csv"
        else:
            raise ValueError(
                f"Unsupported prompt source type: {self.files[0]} (expected .txt or .csv)"
            )

    def _get_dist_files(self) -> List[str]:
        """Return files assigned to this distributed rank (file-level sharding).

        With world_size=8 and 80 CSV shards, rank-r reads shards
        [r, r+8, r+16, ...]. Falls back to all files when:
          - torch.distributed is not initialised, OR
          - world_size == 1, OR
          - sliced list would be empty (fewer files than ranks).
        """
        files = self.files
        try:
            import torch.distributed as dist

            if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
                rank = dist.get_rank()
                world_size = dist.get_world_size()
                sliced = files[rank::world_size]
                files = sliced if sliced else self.files
        except Exception:
            pass
        return files

    def _get_rank(self) -> int:
        """Return current distributed rank (0 when not initialised)."""
        try:
            import torch.distributed as dist

            if dist.is_available() and dist.is_initialized():
                return dist.get_rank()
        except Exception:
            pass
        return 0

    def _iter_txt(self, rng: random.Random, files: Optional[List[str]] = None) -> Iterator[str]:
        """Stream lines from .txt files, looping if infinite."""
        files = list(files if files is not None else self.files)
        while True:
            if self.shuffle_files:
                rng.shuffle(files)
            for fp in files:
                with open(fp, "r", encoding=self.encoding) as f:
                    for line in f:
                        s = line.strip()
                        # Skip blank lines and comment lines.
                        if not s or s.startswith("#"):
                            continue
                        yield s
            if not self.infinite:
                break

    def _iter_csv_file(self, fp: str) -> Iterator[str]:
        """Stream captions from one CSV shard, applying optional score filters."""
        cs = self.csvspec
        with open(fp, "r", encoding=self.encoding, newline="") as f:
            reader = csv.DictReader(f)

            if not reader.fieldnames or cs.caption_field not in reader.fieldnames:
                raise KeyError(
                    f"CSV missing caption field '{cs.caption_field}'. "
                    f"Got fields: {(reader.fieldnames or [])[:20]}..."
                )

            for row in reader:
                cap = (row.get(cs.caption_field) or "").strip()
                if not cap:
                    continue

                # Drop rows below the configured quality thresholds, or whose
                # scores are missing or unparseable.
                if cs.min_clarity is not None:
                    v = _maybe_float(row.get(cs.clarity_field, ""))
                    if v is None or v < cs.min_clarity:
                        continue
                if cs.min_aesthetic is not None:
                    v = _maybe_float(row.get(cs.aesthetic_field, ""))
                    if v is None or v < cs.min_aesthetic:
                        continue

                yield cap

    def _iter_csv(self, rng: random.Random, files: Optional[List[str]] = None) -> Iterator[str]:
        """Iterate CSV shards; approximate shuffle with a buffer if requested."""
        files = list(files if files is not None else self.files)

        while True:
            if self.shuffle_files:
                rng.shuffle(files)

            if self.shuffle_buffer > 0:
                # Approximate shuffle: fill a buffer of captions and, once it is
                # full, yield a uniformly random element for each new caption.
                buf: List[str] = []

                for fp in files:
                    for cap in self._iter_csv_file(fp):
                        buf.append(cap)
                        if len(buf) >= self.shuffle_buffer:
                            j = rng.randrange(len(buf))
                            yield buf.pop(j)

                # Drain whatever remains at the end of the pass.
                while buf:
                    j = rng.randrange(len(buf))
                    yield buf.pop(j)
            else:
                for fp in files:
                    yield from self._iter_csv_file(fp)

            if not self.infinite:
                break

    def __iter__(self) -> Iterator[str]:
        # File-level sharding across distributed ranks; DataLoader workers on the
        # same rank see the same files but draw from differently seeded RNGs.
        dist_files = self._get_dist_files()
        rank = self._get_rank()

        wi = torch.utils.data.get_worker_info()
        worker_id = 0 if wi is None else wi.id
        rng = random.Random(self.seed + 1009 * worker_id + 97 * rank)

        if self.mode == "txt":
            yield from self._iter_txt(rng, dist_files)
        else:
            yield from self._iter_csv(rng, dist_files)


def make_collate_fn(tokenizer, max_prompt_length: int, device: torch.device):
    """Tokenize a List[str] into a [B, L] tensor of input ids.

    IMPORTANT: returns a CPU tensor regardless of the ``device`` argument.
    Move it to the GPU inside the training step (not in the dataloader worker).
    """
    tok_kwargs = {
        "max_length": max_prompt_length,
        "padding": "max_length",
        "padding_side": "left",
        "truncation": True,
        "return_tensors": "pt",
    }

    def collate_fn(prompts: List[str]) -> torch.Tensor:
        return tokenizer(prompts, **tok_kwargs).input_ids

    return collate_fn
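

# Example usage (a sketch): wire the dataset, collate fn, and a DataLoader.
# ``tokenizer`` is assumed to be a HuggingFace-style tokenizer; it is not
# provided by this module, and the batch/length values are placeholders.
#
#   from torch.utils.data import DataLoader
#
#   dataset = PromptDataset("prompts.txt")
#   collate = make_collate_fn(tokenizer, max_prompt_length=77, device=torch.device("cpu"))
#   loader = DataLoader(dataset, batch_size=8, collate_fn=collate, num_workers=2)
#   input_ids = next(iter(loader))  # CPU LongTensor of shape [8, 77]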


class InfiniteDataLoader:
    """Wrap a DataLoader and cycle it indefinitely (works for map-style and iterable datasets)."""

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self._iter = iter(dataloader)

    def __next__(self):
        try:
            return next(self._iter)
        except StopIteration:
            # The underlying loader is exhausted; restart it and keep going.
            self._iter = iter(self.dataloader)
            return next(self._iter)

    def __iter__(self):
        return self
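

# Example usage (a sketch): cycle batches for a fixed number of optimizer steps.
# ``loader`` is the DataLoader from the sketch above and ``num_steps`` is a
# placeholder training-length setting.
#
#   batches = InfiniteDataLoader(loader)
#   for step in range(num_steps):
#       input_ids = next(batches)  # move to GPU inside the training step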