import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer


def train_tokenizer(texts, vocab_size=50000, min_frequency=2):
    """Train a new BPE tokenizer on `texts`, starting from the GPT-2 tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer = tokenizer.train_new_from_iterator(
        texts, vocab_size=vocab_size, min_frequency=min_frequency
    )
    # GPT-2 ships without a pad token; add one so padding works downstream.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.save_pretrained("./tokenizer")
    return tokenizer


def load_tokenizer():
    """Reload the tokenizer saved by `train_tokenizer` from ./tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained("./tokenizer")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    return tokenizer


class TextDataset(Dataset):
    """Wraps a list of raw strings and tokenizes each one on access."""

    def __init__(self, texts, tokenizer, max_length):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Truncate/pad every example to a fixed max_length so batches stack cleanly.
        text = self.texts[idx]
        encodings = self.tokenizer(
            text, truncation=True, padding='max_length', max_length=self.max_length
        )
        return torch.tensor(encodings['input_ids'])


def get_dataloader(dataset_name, config_name, tokenizer, max_length, batch_size):
    raw_dataset = load_dataset(dataset_name, config_name)
    # NOTE: only the first 50 training texts are kept; remove the slice to use the full split.
    texts = raw_dataset['train']['text'][:50]
    dataset = TextDataset(texts, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
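

# Example usage: a minimal sketch, not part of the pipeline above. The dataset
# name ("wikitext" / "wikitext-2-raw-v1") and all hyperparameters below are
# illustrative assumptions, not values fixed by the code above.
if __name__ == "__main__":
    raw = load_dataset("wikitext", "wikitext-2-raw-v1")
    tokenizer = train_tokenizer(raw["train"]["text"][:1000], vocab_size=8000)
    dataloader = get_dataloader(
        "wikitext", "wikitext-2-raw-v1", tokenizer, max_length=128, batch_size=8
    )
    batch = next(iter(dataloader))
    print(batch.shape)  # expected: torch.Size([8, 128]) with the settings above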