from typing import Optional, Iterator, Callable, Any

import torch
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer


def load_text_dataset(tokenizer: AutoTokenizer,
                      kind: str,
                      path: str,
                      name: Optional[str] = None,
                      data_dir: Optional[str] = None,
                      data_files: Optional[str] = None,
                      keep_in_memory: bool = False,
                      revision: Optional[str] = None,
                      split: str = 'train',
                      num_proc: Optional[int] = None,
                      format: Optional[Callable | str] = None) -> Any:
    """Load a plain-text ("base") dataset and render every row into a single `text` column.

    `format` is required despite its `None` default: it must be either a
    `str.format`-style template keyed by column names or a callable mapping a
    row dict to a string.
    """
    assert isinstance(format, str) or callable(format), f'{path=} {format=}'
    assert kind == 'base'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    EOS_TOKEN = tokenizer.eos_token

    def format_dataset(batch):
        texts: list = []
        # Convert the columnar batch into a list of per-row dicts.
        rows = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]

        if callable(format):
            # `format` is a callable mapping a row dict to a string.
            for row in rows:
                text = format(row)

                # Guard against empty renderings with a literal placeholder.
                if not text:
                    text = '[NONE]'

                text += EOS_TOKEN
                texts.append(text)
        else:
            # `format` is a `str.format` template keyed by column names.
            for row in rows:
                text = format.format(**row)

                # Guard against empty renderings with a literal placeholder.
                if not text:
                    text = '[NONE]'

                text += EOS_TOKEN
                texts.append(text)

        return {'text': texts}

    dataset = dataset.map(format_dataset, batched=True)
    return dataset
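# Usage sketch (illustrative only; the dataset paths, column names, and
# tokenizer below are hypothetical, not taken from this module):
#
#   tokenizer = AutoTokenizer.from_pretrained('gpt2')
#
#   # Template-string formatting keyed by column names:
#   ds = load_text_dataset(tokenizer, kind='base', path='some/text-dataset',
#                          format='{title}\n\n{body}')
#
#   # Equivalent callable formatting over the row dict:
#   ds = load_text_dataset(tokenizer, kind='base', path='some/text-dataset',
#                          format=lambda row: row.get('body', '').strip())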
def load_chat_dataset(tokenizer: AutoTokenizer,
                      kind: str,
                      path: str,
                      name: Optional[str] = None,
                      data_dir: Optional[str] = None,
                      data_files: Optional[str] = None,
                      keep_in_memory: bool = False,
                      revision: Optional[str] = None,
                      split: str = 'train',
                      num_proc: Optional[int] = None,
                      field: Optional[str] = None,
                      transform: Optional[Callable] = None) -> Any:
    """Load a chat ("instruct") dataset and render each conversation into a `text` column.

    `field` names the column holding the chat messages; `transform`, if given,
    converts a row (or `row[field]` when `field` is set) into a list of chat
    messages. Without a `transform`, `field` is required.
    """
    assert kind == 'instruct'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    EOS_TOKEN = tokenizer.eos_token

    def format_dataset(batch):
        texts: list = []
        # Convert the columnar batch into a list of per-row dicts.
        rows = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]

        if callable(transform):
            # `transform` produces the chat messages for each row.
            for row in rows:
                if field:
                    messages = transform(row[field])
                else:
                    messages = transform(row)

                text = tokenizer.apply_chat_template(messages, tokenize=False)
                text += EOS_TOKEN
                texts.append(text)
        else:
            # Without a transform, the rows must already hold chat messages under `field`.
            for row in rows:
                if field:
                    messages = row[field]
                else:
                    raise ValueError(field)

                text = tokenizer.apply_chat_template(messages, tokenize=False)
                text += EOS_TOKEN
                texts.append(text)

        return {'text': texts}

    dataset = dataset.map(format_dataset, batched=True)
    return dataset
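# Usage sketch (illustrative only; the dataset paths, field name, and message
# schema below are hypothetical, not taken from this module):
#
#   tokenizer = AutoTokenizer.from_pretrained('some/chat-model')
#
#   # Rows already store a list of {'role': ..., 'content': ...} dicts:
#   ds = load_chat_dataset(tokenizer, kind='instruct', path='some/chat-dataset',
#                          field='messages')
#
#   # Rows need reshaping first, e.g. prompt/response columns into chat messages:
#   def to_messages(row):
#       return [{'role': 'user', 'content': row['prompt']},
#               {'role': 'assistant', 'content': row['response']}]
#
#   ds = load_chat_dataset(tokenizer, kind='instruct', path='some/qa-dataset',
#                          transform=to_messages)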