import os
import time

import torch
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    TextDataset,
    Trainer,
    TrainingArguments,
)


class GptHumorTrainer:
    """Fine-tunes the DistilGPT-2 model stored in SaveState on a plain-text humor corpus."""

    def __init__(self, silent=False) -> None:
        start_time = time.perf_counter()
        self.tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        # Load the previously fine-tuned weights from the local SaveState directory.
        self.model = GPT2LMHeadModel.from_pretrained(self.local_file_path("SaveState"))
        self.model.eval()
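        # eval() only disables dropout-style layers for inference; the Trainer
        # switches the model back into train mode automatically when train() runs.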
        if not silent:
            print(f"Model loading took {time.perf_counter() - start_time:.2f} seconds")

    def local_file_path(self, path):
        # Resolve `path` relative to this script's directory so files are found
        # regardless of the current working directory.
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)

    def train(self, train_file, epochs=3):
        device = torch.device("cpu")
        self.model.to(device)
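        # Training is pinned to the CPU here; swapping in
        # torch.device("cuda" if torch.cuda.is_available() else "cpu") above
        # would use a GPU when one is available.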
        # TextDataset tokenizes train_file and packs it into fixed-length
        # blocks of 128 token IDs.
        train_dataset = TextDataset(
            tokenizer=self.tokenizer,
            file_path=train_file,
            block_size=128,
        )
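        # NOTE: TextDataset is deprecated in recent transformers releases. A
        # rough replacement using the `datasets` library (a sketch, assuming
        # `datasets` is installed; not part of this script) would be:
        #
        #     from datasets import load_dataset
        #     raw = load_dataset("text", data_files=train_file)["train"]
        #     train_dataset = raw.map(
        #         lambda batch: self.tokenizer(batch["text"], truncation=True, max_length=128),
        #         batched=True,
        #         remove_columns=["text"],
        #     )
        #
        # Unlike TextDataset this tokenizes line by line instead of packing the
        # file into contiguous blocks, so sequence lengths vary and the
        # collator would then need a pad token (e.g. the EOS token).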

        # mlm=False selects causal language modeling: the collator copies the
        # inputs as labels, and the model applies the one-token shift internally.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

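        # One Trainer invocation per epoch keeps each epoch's output in its own
        # ./results/epoch_N directory; the same model object is reused, so the
        # fine-tuned weights carry over from one iteration to the next.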
        for epoch in range(epochs):
            training_args = TrainingArguments(
                output_dir=f"./results/epoch_{epoch + 1}",
                overwrite_output_dir=True,
                num_train_epochs=1,  # one pass here; the outer loop drives the total epoch count
                per_device_train_batch_size=3,
                save_strategy="no",  # skip periodic checkpoints; the model is saved explicitly below
                prediction_loss_only=True,
            )

            trainer = Trainer(
                model=self.model,
                args=training_args,
                data_collator=data_collator,
                train_dataset=train_dataset,
            )

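            # With save_strategy="no", train() only writes logs under
            # output_dir; nothing is checkpointed until the explicit save below.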
            trainer.train()

        # Persist the fine-tuned weights so __init__ can reload them next run.
        self.model.save_pretrained(self.local_file_path("SaveState"))
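        # The tokenizer is never modified during training, so only the model
        # weights are persisted; self.tokenizer.save_pretrained(...) could be
        # added to make the SaveState directory fully self-contained.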


if __name__ == "__main__":
    humor_trainer = GptHumorTrainer()
    humor_trainer.train(humor_trainer.local_file_path("TrainData.txt"), epochs=5)