| |
|
|
| from __future__ import annotations |
|
|
| import copy |
| import pickle |
| from copy import deepcopy |
| from dataclasses import dataclass |
| from typing import Any, Callable, Dict, Iterable, List, Optional, Union |
|
|
| import datasets |
| import numpy as np |
| import torch |
| from datasets import Dataset, IterableDataset, interleave_datasets, load_dataset |
| from datasets.iterable_dataset import ShufflingConfig |
| from torch.distributed.checkpoint.stateful import Stateful |
| from torchdata.stateful_dataloader import StatefulDataLoader |
| from transformers import PreTrainedTokenizer |
|
|
| from torchtitan.tools import utils |
| from torchtitan.tools.logging import logger |
|
|
# Import-time side effect of this module: set the Hugging Face `datasets` library's logging
# verbosity to INFO.
datasets.logging.set_verbosity_info()
|
|
class BufferShuffledIterableDataset(IterableDataset):
    """Tokenize a text dataset on the fly and yield fixed-length chunks through a shuffle buffer.

    The token ids of consecutive samples are concatenated into one stream and cut into chunks of
    ``seq_len`` tokens. The chunks are held in a buffer of ``buffer_size`` rows. Each yielded chunk
    is taken from a random row, which is then refilled with the next chunk of the stream. When a
    pass over the data ends, complete chunks still waiting in the stream are yielded in random
    order and the data is iterated again. The shuffle buffer is kept across passes, and iteration
    never stops on its own. Progress (data state, buffer, pending tokens, RNG state) is exposed
    through ``state_dict``/``load_state_dict`` so that iteration can be resumed.

    Args:
        dataset: Source dataset; each sample must have a ``'text'`` or ``'content'`` field.
        tokenizer: Tokenizer used to encode the texts.
        seq_len: Number of tokens per yielded chunk.
        rank: Rank of this process; selects the data shard and offsets the RNG seed.
        world_size: Total number of ranks, i.e. the number of data shards.
        buffer_size: Number of chunks held in the shuffle buffer.
    """

    def __init__(
        self,
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        seq_len: int = 2048,
        rank: int = 0,
        world_size: int = 1,
        buffer_size: int = 1024,
    ) -> None:
        self.dataset = dataset
        self.tokenizer = tokenizer

        # Each rank only iterates its own shard of the data.
        self.data = dataset.shard(world_size, rank)
        self.seq_len = seq_len

        self.rank = rank
        self.world_size = world_size
        self.buffer_size = buffer_size

        # Store buffered ids in the smallest signed integer type that can hold every id.
        # `len(tokenizer)` also counts added tokens, which `vocab_size` leaves out. Signed types
        # are used because PyTorch supports uint16/uint32/uint64 only partially in eager mode.
        n_ids = max(tokenizer.vocab_size, len(tokenizer))
        if n_ids < torch.iinfo(torch.int16).max:
            self.dtype = torch.int16
        elif n_ids < torch.iinfo(torch.int32).max:
            self.dtype = torch.int32
        else:
            self.dtype = torch.int64
        # Data state recorded with the most recently tokenized sample (None before the first one).
        self.states = None
        # Shuffle buffer: empty until first filled, then shaped [buffer_size, seq_len].
        self.buffer = torch.tensor([], dtype=self.dtype)
        # Token stream that has not been moved into the buffer yet.
        self.tokens = []
        # Number of indices already consumed from the current batch of random indices.
        self.rand_id = 0
        # Offset in `self.tokens` of the next chunk to move into the buffer.
        self.token_id = 0
        # Generator state that produced the current batch of random indices.
        self.rng_state = None
        self._epoch = 0

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self._epoch + self.rank)
        if self.rng_state is not None:
            # Resume the random index stream where the checkpoint left it.
            g.set_state(self.rng_state)

        rand_it = self.randint(0, self.buffer_size, g=g)
        if self.states is not None:
            self.data.load_state_dict(self.states)

        # Number of tokens needed to fill the shuffle buffer.
        buffer_tokens = self.buffer_size * self.seq_len

        while True:
            for sample in self.tokenize(self.data):
                self.tokens += sample
                # Fill an empty buffer in one go, stored as [buffer_size, seq_len].
                if len(self.buffer) == 0 and len(self.tokens) >= buffer_tokens:
                    self.buffer = torch.tensor(self.tokens[:buffer_tokens], dtype=self.dtype).view(
                        self.buffer_size, -1
                    )
                    self.tokens = self.tokens[buffer_tokens:]
                if len(self.buffer) == self.buffer_size:
                    yield from self.sample(rand_it)
            # This pass over the data is exhausted. The shuffle buffer is left untouched, so the
            # next pass either keeps sampling from it or (if it was never filled) fills it.
            yield from self._flush_leftover_chunks(g)

    def _flush_leftover_chunks(self, g: torch.Generator):
        """Yield, in random order, the complete chunks still waiting in ``self.tokens``."""
        # Tokens before `token_id` have already been copied into the buffer.
        self.tokens = self.tokens[self.token_id:]
        self.token_id = 0
        n_chunks = len(self.tokens) // self.seq_len
        if n_chunks == 0:
            return
        n_left = n_chunks * self.seq_len
        chunks = torch.tensor(self.tokens[:n_left], dtype=torch.long).view(n_chunks, -1)
        # Drop the chunks from the stream before yielding. A checkpoint taken during this loop
        # may therefore skip some of them, but it will never replay one.
        self.tokens = self.tokens[n_left:]
        for i in torch.randperm(n_chunks, generator=g).tolist():
            yield {'input_ids': chunks[i]}

    def tokenize(self, data, batch_size: int = 64):
        """Yield the token ids of each sample, tokenizing ``batch_size`` texts per tokenizer call.

        Before each sample's ids are yielded, ``self.states`` is set to the data state captured
        right after that sample was read.
        """
        texts, states = [], []
        for sample in data:
            texts.append(self._sample_text(sample))
            states.append(self.data.state_dict())
            if len(texts) == batch_size:
                yield from self._yield_tokenized(texts, states)
                texts, states = [], []
        if texts:
            yield from self._yield_tokenized(texts, states)

    def _yield_tokenized(self, texts, states):
        """Tokenize ``texts`` and yield each sample's ids after recording its data state."""
        for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
            self.states = s
            yield tokenized

    @staticmethod
    def _sample_text(sample):
        """Return the ``'text'`` field of ``sample``, falling back to ``'content'``."""
        if sample.get('text', None) is not None:
            return sample['text']
        if sample.get('content', None) is not None:
            return sample['content']
        raise ValueError(f"No 'text' or 'content' field found in sample:\n{sample}")

    def sample(self, indices):
        """Move every complete ``seq_len`` window of ``self.tokens`` into the buffer.

        Each window replaces a random buffer row (its index is drawn from ``indices``), and the
        row's previous content is yielded as a ``torch.long`` tensor. The consumed tokens are
        dropped from ``self.tokens`` once every window has been placed.
        """
        n_tokens = (len(self.tokens) // self.seq_len) * self.seq_len
        while self.token_id < n_tokens:
            i = next(indices)
            start = self.token_id
            self.token_id += self.seq_len
            # Refill the row *before* yielding. A state_dict taken while the consumer holds this
            # chunk then already contains the swap, so resuming neither replays this chunk nor
            # loses the one that replaced it. `copy=True` detaches the chunk from the buffer.
            chunk = self.buffer[i].to(torch.long, copy=True)
            self.buffer[i] = torch.tensor(self.tokens[start:self.token_id], dtype=self.dtype)
            yield {'input_ids': chunk}
        self.token_id = 0
        self.tokens = self.tokens[n_tokens:]

    def randint(
        self, low: int, high: int, buffer_size: int = 1024, g: Optional[torch.Generator] = None
    ) -> Iterable[int]:
        """Yield random integers in ``[low, high)`` forever, drawn in batches of ``buffer_size``.

        ``self.rng_state`` and ``self.rand_id`` are updated so that a resumed run regenerates
        the current batch and skips the indices that were already consumed.
        """
        if g is None:
            # A fresh generator per call, instead of one shared mutable default argument.
            g = torch.Generator()
        indices = torch.empty(buffer_size, dtype=torch.long)
        while True:
            self.rng_state = g.get_state()
            indices = torch.randint(low, high, (buffer_size,), out=indices, generator=g)
            for i in indices[self.rand_id:].tolist():
                self.rand_id += 1
                yield i
            self.rand_id = 0

    def set_epoch(self, epoch):
        """Set the epoch used to seed the RNG and forward it to the datasets that support it."""
        self._epoch = epoch
        # `self.data` (the shard returned by `shard()`) is the object that is actually iterated,
        # so forward the epoch to it as well as to the original dataset.
        for ds in (self.dataset, self.data):
            if hasattr(ds, 'set_epoch'):
                ds.set_epoch(epoch)

    def state_dict(self):
        """Return a snapshot of the iteration state (copies of the buffer and pending tokens)."""
        return {
            'states': self.states,
            'buffer': self.buffer.clone(),
            'tokens': deepcopy(self.tokens),
            'rand_id': self.rand_id,
            'token_id': self.token_id,
            'rng_state': self.rng_state,
            'epoch': self._epoch,
        }

    def load_state_dict(self, state_dict):
        """Restore the iteration state produced by ``state_dict``."""
        self.states = state_dict['states']
        # Cast to this instance's buffer dtype, because older checkpoints may store another
        # integer type.
        self.buffer = state_dict['buffer'].to(self.dtype, copy=True)
        self.tokens = deepcopy(state_dict['tokens'])
        self.rand_id = state_dict['rand_id']
        self.token_id = state_dict['token_id']
        self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
        self._epoch = state_dict['epoch']
|
|
|
|
class OnlineTokenizedIterableDataset(IterableDataset):
    """Stream consecutive ``seq_len`` windows of token ids, tokenizing a text dataset on the fly.

    Texts are read from this rank's shard, tokenized in batches and concatenated into one token
    stream. The stream is cut into windows that are yielded as ``{'input_ids': LongTensor}``.
    Tokens that do not yet fill a window are carried over into the next pass over the data; the
    outer loop never ends on its own. ``state_dict``/``load_state_dict`` expose the data state
    and the carried-over tokens.
    """

    def __init__(
        self, dataset: Dataset, tokenizer: PreTrainedTokenizer, seq_len: int = 2048, rank: int = 0, world_size: int = 1
    ) -> OnlineTokenizedIterableDataset:
        self.dataset = dataset
        self.tokenizer = tokenizer
        # Only this rank's shard is ever iterated.
        self.data = dataset.shard(world_size, rank)
        self.seq_len, self.rank, self.world_size = seq_len, rank, world_size
        # Data state recorded with the most recently tokenized sample (None before the first one).
        self.states = None
        # Tokens waiting to be emitted as part of a full window.
        self.tokens = []

    def __iter__(self):
        if self.states is not None:
            # Continue reading right after the last sample whose tokens were taken.
            self.data.load_state_dict(self.states)
        while True:
            for token_ids in self.tokenize(self.data):
                self.tokens += token_ids
                while len(self.tokens) >= self.seq_len:
                    window = torch.tensor(self.tokens[:self.seq_len], dtype=torch.long)
                    self.tokens = self.tokens[self.seq_len:]
                    yield {'input_ids': window}

    def tokenize(self, data, buffer_size: int = 64):
        """Yield each sample's token ids, tokenizing ``buffer_size`` texts per tokenizer call.

        Before a sample's ids are yielded, ``self.states`` is set to the data state captured
        right after that sample was read.
        """
        pending_texts, pending_states = [], []
        for sample in data:
            pending_texts.append(self._sample_text(sample))
            pending_states.append(self.data.state_dict())
            if len(pending_texts) == buffer_size:
                yield from self._yield_tokenized(pending_texts, pending_states)
                pending_texts, pending_states = [], []
        if pending_texts:
            yield from self._yield_tokenized(pending_texts, pending_states)

    def _yield_tokenized(self, texts, states):
        """Tokenize ``texts`` and yield each sample's ids after recording its data state."""
        encoded = self.tokenizer(texts, return_attention_mask=False)['input_ids']
        for state, token_ids in zip(states, encoded):
            self.states = state
            yield token_ids

    @staticmethod
    def _sample_text(sample):
        """Return the ``'text'`` field of ``sample``, or ``'content'`` when ``'text'`` is missing."""
        for field in ('text', 'content'):
            if sample.get(field) is not None:
                return sample[field]
        raise ValueError(f"No 'text' or 'content' field found in sample:\n{sample}")

    def state_dict(self):
        """Return the data state and a copy of the carried-over tokens."""
        return {'states': self.states, 'tokens': deepcopy(self.tokens)}

    def load_state_dict(self, state_dict):
        """Restore what ``state_dict`` returned."""
        self.states, self.tokens = state_dict['states'], deepcopy(state_dict['tokens'])
|
|
|
|
class BufferShuffledExamplesIterable(datasets.iterable_dataset.BufferShuffledExamplesIterable):
    """Checkpointable version of the `datasets` buffer-shuffling examples iterable.

    While iterating, the shuffle buffer and the position in the random index stream are written
    into ``self._state_dict``. A state dict captured mid-iteration therefore restores the same
    buffer contents and the same sequence of random choices. When no state dict has been
    initialised, iteration still works but nothing is recorded.
    """

    def _init_state_dict(self) -> dict:
        """Build the state dict: the wrapped iterable's state plus the shuffle-buffer bookkeeping."""
        self._state_dict = self.ex_iterable._init_state_dict()
        # The buffer list is wrapped in a 1-tuple, so `load_state_dict` below replaces it
        # wholesale (tuples fall through to `return new_state`) instead of walking it
        # element by element like a list.
        self._state_dict['mem_buffer'] = ([],)
        # Generator state from which the current batch of random indices is drawn.
        self._state_dict['bit_generator_state'] = self.generator.bit_generator.state
        # Indices already consumed from that batch while examples stream through the buffer.
        self._state_dict['bit_generator_index_offset'] = 0
        # Examples already yielded from the final permutation of the remaining buffer.
        self._state_dict['bit_generator_index_offset_shuffle'] = 0
        return self._state_dict

    def __iter__(self):
        buffer_size = self.buffer_size
        rng = deepcopy(self.generator)
        # Use the list stored in the state dict, so the recorded buffer always mirrors the live
        # one. Without a state dict, fall back to a private list instead of failing on
        # `None['mem_buffer']`, matching the `if self._state_dict` guards below.
        mem_buffer = self._state_dict['mem_buffer'][0] if self._state_dict else []

        index_offset = self._state_dict['bit_generator_index_offset'] if self._state_dict else 0
        if self._state_dict:
            rng.bit_generator.state = self._state_dict['bit_generator_state']
        indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size=buffer_size)

        # Skip the indices of the current batch that were consumed before the checkpoint.
        for _ in range(index_offset):
            next(indices_iterator)

        for x in self.ex_iterable:
            if len(mem_buffer) < buffer_size:
                mem_buffer.append(x)
                continue
            i = next(indices_iterator)
            index_offset = (index_offset + 1) % buffer_size
            if self._state_dict:
                self._state_dict['bit_generator_index_offset'] = index_offset
                if index_offset == 0:
                    # A whole batch (random_batch_size == buffer_size) was consumed; record
                    # the generator state from which the next batch is drawn.
                    self._state_dict['bit_generator_state'] = rng.bit_generator.state
            # Swap the incoming example in before yielding the old one, so the recorded buffer
            # is already up to date at the yield.
            selected = mem_buffer[i]
            mem_buffer[i] = x
            yield selected

        # The input is exhausted: yield the rest of the buffer in random order. The permutation
        # is drawn from the last recorded generator state, so a resumed run reproduces it and
        # can skip the examples that were already yielded.
        index_offset = self._state_dict['bit_generator_index_offset_shuffle'] if self._state_dict else 0
        if self._state_dict:
            rng.bit_generator.state = self._state_dict['bit_generator_state']

        for i in rng.permutation(len(mem_buffer))[index_offset:].tolist():
            index_offset = index_offset + 1
            if self._state_dict:
                self._state_dict['bit_generator_index_offset_shuffle'] = index_offset
            yield mem_buffer[i]

    def shuffle_data_sources(self, generator: np.random.Generator) -> BufferShuffledExamplesIterable:
        """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
        return BufferShuffledExamplesIterable(
            self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
        )

    def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> BufferShuffledExamplesIterable:
        """Keep only the requested shard."""
        return BufferShuffledExamplesIterable(
            self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
            buffer_size=self.buffer_size,
            generator=self.generator,
        )

    def load_state_dict(self, state_dict: dict) -> dict:
        """Merge ``state_dict`` into ``self._state_dict`` in place and return the result.

        Dicts are merged key by key and lists element by element (over the current list's
        length). Any other value, including the tuple-wrapped shuffle buffer, is replaced.
        """

        def _inner_load_state_dict(state, new_state):
            if new_state is not None and isinstance(state, dict):
                for key in new_state:
                    state[key] = _inner_load_state_dict(state[key], new_state[key])
                return state
            elif new_state is not None and isinstance(state, list):
                for i in range(len(state)):
                    state[i] = _inner_load_state_dict(state[i], new_state[i])
                return state
            return new_state

        return _inner_load_state_dict(self._state_dict, state_dict)
|
|
|
|
def shuffle(
    dataset: IterableDataset,
    seed: int = 42,
    generator: np.random.Generator = None,
    buffer_size: int = 1024,
):
    """Return a buffer-shuffled copy of a streaming ``dataset``.

    A new ``IterableDataset`` is built from ``dataset``'s internals, with its examples iterable
    wrapped in this module's resumable ``BufferShuffledExamplesIterable``.

    Args:
        dataset: The streaming dataset to shuffle.
        seed: Seed for a new ``np.random.default_rng`` when ``generator`` is not given; also
            recorded as ``_original_seed`` in the shuffling config.
        generator: Optional generator to use instead. It is deep-copied, so the caller's
            instance is not advanced.
        buffer_size: Number of examples held in the shuffle buffer.
    """
    if generator is None:
        rng = np.random.default_rng(seed)
    else:
        rng = copy.deepcopy(generator)
    shuffled_examples = BufferShuffledExamplesIterable(dataset._ex_iterable, buffer_size=buffer_size, generator=rng)
    return IterableDataset(
        ex_iterable=shuffled_examples,
        info=dataset._info.copy(),
        split=dataset._split,
        formatting=dataset._formatting,
        shuffling=ShufflingConfig(generator=rng, _original_seed=seed),
        distributed=copy.deepcopy(dataset._distributed),
        token_per_repo_id=dataset._token_per_repo_id,
    )
|
|
|
|
@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling. Inputs are dynamically padded if `varlen=False`.
    If `varlen=True`, sequences are expected to be concatenated, and labels match inputs.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data. With `varlen=False`, its `pad_token_id` is
            required whenever examples have different lengths. With `varlen=True` and no
            precomputed `cu_seqlens`, its `bos_token_id` (preferred) or `eos_token_id` marks the
            document boundaries.
        context_len (`int`, optional):
            When `varlen=True`, sequences longer than this length within a document
            (as determined by `cu_seqlens`) will be further chunked.
        varlen (`bool`):
            Whether to handle variable length concatenated sequences (`True`) or padded batches (`False`).

    Returns:
        A dictionary with the following keys:
        - `input_ids`: Tensor of input IDs. Shape `[batch_size, seq_len]` if `varlen=False`, `[1, total_len]` if `varlen=True`.
        - `labels`: Tensor of labels. Shape matches `input_ids`. Padding positions are masked with -100 if `varlen=False`.
        - `attention_mask`: Tensor indicating non-padding tokens (only if `varlen=False`). Shape matches `input_ids`.
        - `cu_seqlens`: int32 tensor of cumulative sequence lengths (only if `varlen=True`). Shape `[num_sequences + 1]`;
          an empty input gives `[0, 0]`.

    Raises:
        ValueError: if padding is needed but the tokenizer has no pad token; if `varlen=True` and
            more than one example is given; if boundaries must be inferred but the tokenizer has
            neither a BOS nor an EOS token; or if `cu_seqlens` is malformed.

    NOTE: When `varlen=True`, the `batch_size` must be 1.
    """

    tokenizer: PreTrainedTokenizer
    context_len: Optional[int] = None
    varlen: bool = False

    def __call__(self, examples: List[Union[List[int], Dict[str, Any]]]) -> Dict[str, Any]:
        if not isinstance(examples[0], dict):
            examples = [{'input_ids': example} for example in examples]
        examples = [self._tensorize(example) for example in examples]
        if self.varlen:
            return self._collate_varlen(examples)
        return self._collate_padded(examples)

    @staticmethod
    def _tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only `input_ids`/`cu_seqlens`; convert lists to long tensors and arrays via `from_numpy`."""
        tensorized = {}
        for key in ('input_ids', 'cu_seqlens'):
            if key not in example:
                continue
            value = example[key]
            if isinstance(value, list):
                tensorized[key] = torch.tensor(value, dtype=torch.long)
            elif isinstance(value, np.ndarray):
                tensorized[key] = torch.from_numpy(value)
            else:
                tensorized[key] = value
        return tensorized

    def _collate_padded(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Stack equal-length examples, or pad them with the tokenizer, and build masked labels."""
        first_len = examples[0]['input_ids'].size(0)
        if all(example['input_ids'].size(0) == first_len for example in examples):
            input_ids = torch.stack([example['input_ids'] for example in examples], dim=0)
            batch = {'input_ids': input_ids, 'attention_mask': torch.ones_like(input_ids)}
        else:
            if self.tokenizer.pad_token_id is None:
                raise ValueError(
                    f'You are attempting to pad samples but the tokenizer you are using '
                    f'({self.tokenizer.__class__.__name__}) does not have a pad token.'
                )
            batch = self.tokenizer.pad(examples, return_tensors='pt', return_attention_mask=True)

        labels = batch['input_ids'].clone()
        if 'attention_mask' in batch:
            # Padding positions are ignored by the loss.
            labels[batch['attention_mask'] == 0] = -100
        batch['labels'] = labels
        return batch

    def _collate_varlen(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build a `[1, total_len]` batch with 1-D int32 `cu_seqlens` and labels equal to inputs."""
        if len(examples) > 1:
            raise ValueError('The batch size must be 1 for inputs with variable lengths (varlen=True).')

        input_ids = torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)
        batch = {'input_ids': input_ids}

        if 'cu_seqlens' in examples[0]:
            # Keep precomputed boundaries 1-D, the same layout as the inferred ones, because the
            # checks and the splitting below index them as a flat vector.
            cu_seqlens = torch.cat([example['cu_seqlens'] for example in examples], dim=0).to(dtype=torch.int32)
        else:
            cu_seqlens = self._infer_cu_seqlens(input_ids)

        self._validate_cu_seqlens(cu_seqlens, input_ids.size(1))

        if self.context_len is not None:
            cu_seqlens = self._split_long_sequences(cu_seqlens, input_ids.size(1), input_ids.device)

        batch['cu_seqlens'] = cu_seqlens
        batch['labels'] = input_ids.clone()
        return batch

    def _infer_cu_seqlens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Derive document boundaries of a `[1, total_len]` stream from BOS (preferred) or EOS tokens."""
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id
        if bos_id is None and eos_id is None:
            raise ValueError(
                'For varlen=True without precomputed cu_seqlens, the tokenizer must have either a bos_token_id '
                'or an eos_token_id defined to act as sequence separators.'
            )

        total_len = input_ids.size(1)
        device = input_ids.device
        if total_len == 0:
            # Nothing to split; the indexing below would fail on an empty stream.
            return torch.tensor([0, 0], dtype=torch.int32, device=device)

        if bos_id is not None:
            # Every BOS starts a document. Also start one at 0 if the stream does not begin with a BOS.
            parts = []
            if input_ids[0, 0] != bos_id:
                parts.append(torch.tensor([0], device=device))
            parts.append(torch.where(input_ids.eq(bos_id))[1])
            parts.append(torch.tensor([total_len], device=device))
        else:
            # Every EOS ends a document. Close a trailing document that has no EOS.
            parts = [torch.tensor([0], device=device), torch.where(input_ids.eq(eos_id))[1] + 1]
            if input_ids[0, -1] != eos_id:
                parts.append(torch.tensor([total_len], device=device))
        return torch.cat(parts, dim=0).to(dtype=torch.int32)

    @staticmethod
    def _validate_cu_seqlens(cu_seqlens: torch.Tensor, total_len: int) -> None:
        """Check that `cu_seqlens` is non-decreasing, starts at 0 and ends at `total_len`."""
        if cu_seqlens.numel() < 2:
            raise ValueError(f'Calculated cu_seqlens must have at least start and end: {cu_seqlens}')
        if not torch.all(cu_seqlens[1:] >= cu_seqlens[:-1]):
            raise ValueError(f'Calculated cu_seqlens are not monotonically increasing: {cu_seqlens}')
        if cu_seqlens[0] != 0:
            raise ValueError(f'Calculated cu_seqlens do not start at 0: {cu_seqlens}')
        if cu_seqlens[-1] != total_len:
            raise ValueError(
                f'Calculated cu_seqlens do not end at total length {total_len}: '
                f'{cu_seqlens}'
            )

    def _split_long_sequences(self, cu_seqlens: torch.Tensor, total_len: int, device) -> torch.Tensor:
        """Add a boundary every `context_len` tokens inside each document."""
        starts = cu_seqlens[:-1].tolist()
        ends = cu_seqlens[1:].tolist()
        boundaries = [
            torch.arange(start, end, self.context_len, device=device)
            for start, end in zip(starts, ends)
            if start < end
        ]
        if not boundaries:
            # Only empty documents (the stream is empty): keep the `[0, 0]` form.
            return torch.tensor([0, 0], dtype=torch.int32, device=device)
        final_end_point = torch.tensor([total_len], device=device)
        return torch.unique(torch.cat(boundaries + [final_end_point]).to(dtype=torch.int32))
|
|
|
|
class ParallelAwareDataLoader(StatefulDataLoader, Stateful):
    """
    A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.

    The loader state is pickled and stored under the key ``rank_{rank}``, so each data-parallel
    rank's entry has a distinct key. On load, each rank restores only its own entry.
    """

    def __init__(
        self,
        rank: int,
        dataset: IterableDataset,
        batch_size: int,
        collate_fn: Callable,
        num_workers: int = 0,
        pin_memory: bool = False,
        prefetch_factor: Optional[int] = 2,
        persistent_workers: bool = False,
        snapshot_every_n_steps: Optional[int] = 1,
    ):
        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            num_workers=num_workers,
            pin_memory=pin_memory,
            # PyTorch data loaders raise ValueError when a prefetch factor is given without
            # worker processes, so only forward it when num_workers > 0.
            prefetch_factor=prefetch_factor if num_workers > 0 else None,
            persistent_workers=persistent_workers,
            snapshot_every_n_steps=snapshot_every_n_steps,
        )
        self.rank = rank

    def state_dict(self) -> Dict[str, Any]:
        """Return this rank's loader state, pickled, under the key ``rank_{rank}``."""
        return {f'rank_{self.rank}': pickle.dumps(super().state_dict())}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore this rank's loader state from ``state_dict``.

        An empty ``state_dict`` is ignored. A missing ``rank_{rank}`` entry is logged as a
        warning and ignored.

        NOTE(review): ``pickle.loads`` can execute arbitrary code; only load checkpoints from
        trusted sources.
        """
        if not state_dict:
            return

        if f'rank_{self.rank}' not in state_dict:
            logger.warning(f'DataLoader state is empty for dp rank {self.rank}, expected key rank_{self.rank}')
            return
        super().load_state_dict(pickle.loads(state_dict[f'rank_{self.rank}']))
|
|
|
|
def build_dataset(
    dataset: str,
    dataset_name: str = None,
    dataset_split: str = 'train',
    data_dir: str = None,
    data_files: str = None,
    data_probs: Optional[Union[str, List[float]]] = None,
    streaming: bool = False,
    dp_degree: Optional[int] = None,
    num_workers: int = 32,
    seed: Optional[int] = None,
) -> IterableDataset:
    """Load one dataset, or several interleaved datasets, for training.

    ``dataset`` is a path accepted by ``load_dataset``. A comma-separated value selects several
    datasets. ``dataset_name``, ``dataset_split``, ``data_dir`` and ``data_files`` are then
    comma-separated as well, with one entry per dataset (empty names/dirs become ``None`` and
    empty splits become ``"train"``). ``data_probs`` gives each dataset's sampling probability
    for ``interleave_datasets`` (``"all_exhausted"`` stopping strategy).

    When ``dp_degree`` is set, the data is arranged to have at least
    ``dp_degree * max(num_workers, 1)`` shards. Non-streaming data is converted with
    ``to_iterable_dataset`` using that many shards. Streaming data with fewer shards is re-loaded
    without streaming and converted the same way.

    Args:
        dataset: Dataset path, or comma-separated paths.
        dataset_name: Config name(s) passed as ``name`` to ``load_dataset``.
        dataset_split: Split(s) to load.
        data_dir: Data directory (or directories).
        data_files: Data file(s), one per dataset.
        data_probs: Sampling probabilities for multiple datasets, either as a comma-separated
            string or as a sequence of numbers. Required for multiple datasets.
        streaming: Whether to load in streaming mode.
        dp_degree: Data-parallel degree, used to compute the minimum number of shards.
        num_workers: Dataloader workers per rank; also used as ``num_proc`` when not streaming.
        seed: Shuffle seed. When ``None``, no shuffling is done.

    Returns:
        The prepared dataset. With ``streaming=False`` and no ``dp_degree``, this is the
        map-style dataset returned by ``load_dataset``/``shuffle``, not an ``IterableDataset``.

    Raises:
        ValueError: if several datasets are given without ``data_probs``, or a subset has
            neither a ``text`` nor a ``content`` column.
        AssertionError: if a comma-separated argument has the wrong number of entries.
    """
    color = utils.Color
    # Every data-parallel rank x dataloader worker needs at least one shard. With
    # `num_workers == 0` the main process does the loading and still needs one; without the
    # `max` this would ask `to_iterable_dataset` for zero shards.
    min_num_shards = dp_degree * max(num_workers, 1) if dp_degree else None

    if ',' not in dataset:
        # `dataset` is rebound to the loaded object below; keep the path for the re-load in
        # the resharding branch.
        path = dataset
        dataset = load_dataset(
            path=path,
            name=dataset_name,
            split=dataset_split,
            data_dir=data_dir,
            data_files=data_files,
            streaming=streaming,
            num_proc=num_workers if not streaming else None,
        )
        if seed is not None:
            logger.info(f"Shuffling the dataset with seed {seed}")
        if not streaming:
            if seed is not None:
                dataset = dataset.shuffle(seed=seed)
            if min_num_shards is not None:
                dataset = dataset.to_iterable_dataset(num_shards=min_num_shards)
        elif min_num_shards is not None and dataset.num_shards < min_num_shards:
            logger.warning(
                f"{color.red}"
                f"Dataset {path} has insufficient shards ({dataset.num_shards}). "
                f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
                f"{num_workers} dataloader workers. "
                f"Disabling the streaming mode and resharding dataset to {min_num_shards} shards."
                f"{color.reset}"
            )
            dataset = load_dataset(
                path=path,
                name=dataset_name,
                split=dataset_split,
                data_dir=data_dir,
                data_files=data_files,
                streaming=False,
                num_proc=num_workers,
            )
            if seed is not None:
                dataset = dataset.shuffle(seed=seed)
            dataset = dataset.to_iterable_dataset(num_shards=min_num_shards)
        elif seed is not None:
            dataset = shuffle(dataset, seed=seed)
        return dataset

    # Multiple datasets. `paths` deliberately avoids shadowing the `datasets` module.
    paths = dataset.split(",")
    if dataset_name is not None:
        dataset_names = [name or None for name in dataset_name.split(",")]
        assert len(dataset_names) == len(paths), (
            "The number of dataset names must match the number of datasets"
        )
    else:
        dataset_names = [None] * len(paths)
    if dataset_split is not None:
        dataset_splits = [split or "train" for split in dataset_split.split(",")]
        assert len(dataset_splits) == len(paths), (
            "The number of dataset splits must match the number of datasets"
        )
    else:
        dataset_splits = ["train"] * len(paths)
    if data_dir is not None:
        data_dirs = [d or None for d in data_dir.split(",")]
        assert len(data_dirs) == len(paths), (
            "The number of data dirs must match the number of datasets"
        )
    else:
        data_dirs = [None] * len(paths)
    if data_files is not None:
        data_files = data_files.split(",")
        assert len(data_files) == len(paths), (
            "The number of data files must match the number of datasets"
        )
    else:
        data_files = [None] * len(paths)
    if data_probs is None:
        raise ValueError(
            "Data sampling probabilities are required if using multiple datasets"
        )
    # Accept both the comma-separated string form and an already-split sequence.
    if isinstance(data_probs, str):
        data_probs = data_probs.split(",")
    data_probs = [float(p) for p in data_probs]
    assert len(data_probs) == len(paths), (
        "The number of data probabilities must match the number of datasets"
    )

    subsets = []
    for i, prob in enumerate(data_probs):
        subset = load_dataset(
            path=paths[i],
            name=dataset_names[i],
            split=dataset_splits[i],
            data_dir=data_dirs[i],
            data_files=data_files[i],
            streaming=streaming,
            num_proc=num_workers if not streaming else None,
        )
        logger.info(
            f"Subset {color.cyan}{paths[i]}"
            + (f":{dataset_names[i]} " if dataset_names[i] else " ")
            + f"(p = {prob:.3f}){color.reset}:\n"
            + f"{subset}"
        )

        if seed is not None:
            logger.info(f"Shuffling the dataset with seed {seed}")
        if not streaming:
            if seed is not None:
                subset = subset.shuffle(seed=seed)
            if min_num_shards is not None:
                subset = subset.to_iterable_dataset(num_shards=min_num_shards)
        elif min_num_shards is not None and subset.num_shards < min_num_shards:
            logger.warning(
                f"{color.red}"
                f"Dataset {paths[i]} has insufficient shards ({subset.num_shards}). "
                f"Need {min_num_shards} shards minimum for desired data parallel workers × "
                f"{num_workers} dataloader workers. "
                f"Resharding dataset to {min_num_shards} shards and disabling streaming mode."
                f"{color.reset}"
            )
            subset = load_dataset(
                path=paths[i],
                name=dataset_names[i],
                split=dataset_splits[i],
                data_dir=data_dirs[i],
                data_files=data_files[i],
                streaming=False,
                num_proc=num_workers,
            )
            if seed is not None:
                subset = subset.shuffle(seed=seed)
            subset = subset.to_iterable_dataset(num_shards=min_num_shards)
        elif seed is not None:
            # The per-subset shuffle buffer shrinks as the number of subsets grows (minimum 128).
            subset = shuffle(subset, seed=seed, buffer_size=max(128, 1024 // len(paths)))

        # Keep a single text column so that the subsets can be interleaved.
        if "text" in subset.column_names:
            subset = subset.select_columns("text")
        elif "content" in subset.column_names:
            subset = subset.select_columns("content")
        else:
            raise ValueError(
                f"Subset {paths[i]} has no 'text' or 'content' column"
            )
        subsets.append(subset)

    logger.info(
        f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
    )
    dataset = interleave_datasets(
        datasets=subsets,
        probabilities=data_probs,
        stopping_strategy="all_exhausted",
        seed=seed,
    )
    logger.info(f"{dataset}")
    return dataset
|
|
|
|
def build_dataloader(
    dataset: IterableDataset,
    tokenizer: PreTrainedTokenizer,
    rank: int,
    world_size: int,
    batch_size: int,
    seq_len: int,
    context_len: Optional[int] = None,
    varlen: bool = False,
    num_workers: int = 0,
    pin_memory: bool = False,
    persistent_workers: bool = False,
    snapshot_every_n_steps: Optional[int] = 1,
    prefetch_factor: Optional[int] = 2,
):
    """Tokenize ``dataset`` online into ``seq_len`` windows and wrap it in a rank-aware loader.

    Args:
        dataset: Text dataset; it is sharded by ``rank``/``world_size``.
        tokenizer: Tokenizer used both for encoding and by the collator.
        rank: Data-parallel rank of this process.
        world_size: Number of data-parallel ranks.
        batch_size: Examples per batch. With ``varlen=True`` the collator requires 1.
        seq_len: Tokens per training window.
        context_len: Optional maximum document chunk length used when building ``cu_seqlens``.
        varlen: Whether the collator produces variable-length (``cu_seqlens``) batches.
        num_workers: Dataloader worker processes.
        pin_memory: Passed to the data loader.
        persistent_workers: Passed to the data loader.
        snapshot_every_n_steps: How often the stateful loader snapshots its state.
        prefetch_factor: Batches prefetched per worker. It is only forwarded when
            ``num_workers > 0``, because PyTorch data loaders reject it otherwise.

    Returns:
        A ``ParallelAwareDataLoader`` over an ``OnlineTokenizedIterableDataset``.
    """
    dataset = OnlineTokenizedIterableDataset(
        dataset=dataset, tokenizer=tokenizer, seq_len=seq_len, rank=rank, world_size=world_size
    )
    return ParallelAwareDataLoader(
        rank=rank,
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=DataCollatorForLanguageModeling(tokenizer=tokenizer, context_len=context_len, varlen=varlen),
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
        persistent_workers=persistent_workers,
        snapshot_every_n_steps=snapshot_every_n_steps,
    )
|
|