import gc
from typing import Optional, Iterator, Callable

import torch
from datasets import load_dataset
from litgpt.tokenizer import Tokenizer
from transformers import AutoTokenizer

def batch_text_iterator(kind: str,
                        path: str,
                        name: Optional[str]=None,
                        data_dir: Optional[str]=None,
                        data_files: Optional[str]=None,
                        keep_in_memory: bool=False,
                        revision: Optional[str]=None,
                        split: str='train',
                        num_proc: Optional[int]=None,
                        format: Optional[Callable | str]=None) -> Iterator[str]:
    # `format` is required in practice despite the Optional default:
    # fail fast with context if it is missing or of the wrong type.
    assert isinstance(format, str) or callable(format), f'{path=} {format=}'
    assert kind == 'base'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(format):
        # A callable formatter maps a full dataset row to the final text.
        for row in dataset:
            text = format(row)
            yield text
    else:
        # A string formatter is treated as a template over the row's columns.
        for row in dataset:
            text = format.format(**row)
            yield text

    # Release the dataset eagerly; these iterators can hold large corpora.
    del dataset
    gc.collect()
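
# Example usage (a sketch -- 'karpathy/tiny_shakespeare' and the '{text}'
# template are illustrative placeholders, not names this module depends on):
#
#   for text in batch_text_iterator(kind='base',
#                                   path='karpathy/tiny_shakespeare',
#                                   format='{text}'):
#       print(text[:80])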

def batch_chat_iterator(kind: str,
                        path: str,
                        name: Optional[str]=None,
                        data_dir: Optional[str]=None,
                        data_files: Optional[str]=None,
                        keep_in_memory: bool=False,
                        revision: Optional[str]=None,
                        split: str='train',
                        num_proc: Optional[int]=None,
                        field: Optional[str]=None,
                        transform: Optional[Callable]=None) -> Iterator[list[dict[str, str]]]:
    assert kind == 'instruct'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(transform):
        # A transform maps a row (or a single column of it) to a list of
        # chat messages: [{'role': ..., 'content': ...}, ...].
        for row in dataset:
            if field:
                messages = transform(row[field])
            else:
                messages = transform(row)

            yield messages
    else:
        # Without a transform, `field` must name a column that already
        # holds messages in chat format.
        for row in dataset:
            if field:
                messages = row[field]
            else:
                raise ValueError(field)

            yield messages

    # Release the dataset eagerly; these iterators can hold large corpora.
    del dataset
    gc.collect()
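
# Example usage (a sketch -- 'HuggingFaceH4/ultrachat_200k', its 'train_sft'
# split, and its 'messages' column are illustrative, not names this module
# depends on):
#
#   for messages in batch_chat_iterator(kind='instruct',
#                                       path='HuggingFaceH4/ultrachat_200k',
#                                       split='train_sft',
#                                       field='messages'):
#       print(messages[0]['role'], messages[0]['content'][:40])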

def tokenize_text_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    # `hf_tokenizer` is unused here; it is kept so all tokenize_*_fn
    # helpers share the same signature.
    for text in batch_text_iterator(**dataset_config):
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
        yield text_ids
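
# Example usage (a sketch -- the checkpoint path and dataset config are
# placeholders; None is acceptable for hf_tokenizer since this helper
# never touches it):
#
#   tokenizer = Tokenizer('checkpoints/my-model')
#   config = {'kind': 'base', 'path': 'karpathy/tiny_shakespeare', 'format': '{text}'}
#   for ids in tokenize_text_fn(config, hf_tokenizer=None, tokenizer=tokenizer):
#       ...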

def tokenize_chat_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    for messages in batch_chat_iterator(**dataset_config):
        # Render the messages to a single string with the HF chat template;
        # the template is expected to supply its own special tokens, so no
        # BOS/EOS is appended here.
        text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
        yield text_ids
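
# Example usage (a sketch -- model and dataset names are placeholders):
#
#   hf_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2')
#   tokenizer = Tokenizer('checkpoints/my-model')
#   config = {'kind': 'instruct', 'path': 'HuggingFaceH4/ultrachat_200k',
#             'split': 'train_sft', 'field': 'messages'}
#   for ids in tokenize_chat_fn(config, hf_tokenizer, tokenizer):
#       ...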

def tokenize_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    # Dispatch on the dataset kind; skip rows that fail to format or
    # tokenize rather than aborting a long preprocessing run.
    if dataset_config['kind'] == 'base':
        for text in batch_text_iterator(**dataset_config):
            try:
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
            except Exception as e:
                print(f'Skip base row: {e=} {type(text)=} {text=}')
                continue

            yield text_ids
    elif dataset_config['kind'] == 'instruct':
        for messages in batch_chat_iterator(**dataset_config):
            try:
                text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
            except Exception as e:
                print(f'Skip instruct row: {e=} {type(messages)=} {messages=}')
                continue

            yield text_ids
    else:
        raise ValueError(dataset_config['kind'])
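
# Example usage (a sketch -- names are placeholders; in practice the config
# list would come from this repo's dataset configuration):
#
#   configs = [
#       {'kind': 'base', 'path': 'karpathy/tiny_shakespeare', 'format': '{text}'},
#       {'kind': 'instruct', 'path': 'HuggingFaceH4/ultrachat_200k',
#        'split': 'train_sft', 'field': 'messages'},
#   ]
#   for cfg in configs:
#       for ids in tokenize_fn(cfg, hf_tokenizer, tokenizer):
#           ...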