dkhan05 commited on
Commit
10161d9
·
verified ·
1 Parent(s): 8ef6488

Create basic

Browse files
Files changed (1) hide show
  1. basic +36 -0
basic ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Fine-tune GPT-2 on WikiText-2 (raw) with the Hugging Face Trainer.

Downloads the dataset, tokenizer, and model, tokenizes the corpus with
fixed-length padding, and runs causal-language-model training.
"""

from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

# Load dataset (train/validation/test splits of raw WikiText-2).
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Load tokenizer and model.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# BUG FIX: GPT-2 ships without a pad token, so padding="max_length" below
# would raise "Asking to pad but the tokenizer does not have a padding
# token." Reuse EOS as the pad token (the standard workaround for GPT-2).
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id


def tokenize_function(examples):
    """Tokenize a batch of examples, padding/truncating to the model max length.

    Args:
        examples: a batch dict from ``datasets.map`` with a ``"text"`` column.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` lists.
    """
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)

# BUG FIX: the tokenized dataset has no "labels" column, so the Trainer
# would compute no loss and training would fail. The causal-LM collator
# (mlm=False) copies input_ids into labels at batch time.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training arguments: 3 epochs, small per-device batch, periodic checkpoints
# capped at 2 to bound disk usage.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
    logging_dir="./logs",
)

# Trainer wires model, data, and collator together.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
)

# Train the model.
trainer.train()