LH-Tech-AI committed on
Commit e1845ba · verified · 1 Parent(s): 123d655

Update finetune.py


The official finetuning script.
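The script resumes from the latest `.pt` checkpoint in `/home/user/350m_fineweb`, finetunes it on the packed `alpaca_cleaned_mixed` token/mask files with prompt tokens masked out of the loss, and writes `SmaLLMPro_Final.pt` to the output directory. For reference, a minimal sketch of reloading that artifact, assuming the repo's nanoGPT-style `model.py` (the same `GPT`/`GPTConfig` classes this script imports; none of this is shown in the commit itself):

import torch
from model import GPTConfig, GPT  # same model.py finetune.py imports

ckpt = torch.load('/home/user/350m_SmaLLMPro_Final/SmaLLMPro_Final.pt', map_location='cpu')
model = GPT(GPTConfig(**ckpt['model_args']))  # 'model_args' is saved by finetune.py
model.load_state_dict(ckpt['model'])          # 'model' holds the raw state dict
model.eval()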

Files changed (1)
  1. finetune.py +122 -113
finetune.py CHANGED
@@ -1,120 +1,129 @@
  import os
+ import time
+ import math
+ import torch
+ from model import GPTConfig, GPT
+
  import numpy as np
- import tiktoken
- from datasets import load_dataset
- from tqdm import tqdm
-
- OUTPUT_DIR = "data/alpaca_cleaned_mixed"
- TOKENIZER_NAME = "gpt2"
- SEED = 1337
-
- FINEWEB_SAMPLES = 2500
-
- enc = tiktoken.get_encoding(TOKENIZER_NAME)
- EOS_TOKEN = "<|endoftext|>"
-
- def format_prompt_with_mask(instruction, input_text, output):
-     """
-     Formats the prompt and builds the loss mask.
-     Format:
-     Instruction: ...
-     Input: ... (optional)
-     Response: ... <|endoftext|>
-     """
-     if input_text and input_text.strip():
-         prompt_text = f"Instruction:\n{instruction}\n\nInput:\n{input_text}\n\nResponse:\n"
-     else:
-         prompt_text = f"Instruction:\n{instruction}\n\nResponse:\n"
-
-     completion_text = f"{output}{EOS_TOKEN}"
-
-     prompt_ids = enc.encode(prompt_text, allowed_special={'<|endoftext|>'})
-     completion_ids = enc.encode(completion_text, allowed_special={'<|endoftext|>'})
-
-     full_ids = prompt_ids + completion_ids
-
-     mask = [0] * len(prompt_ids) + [1] * len(completion_ids)
-
-     return full_ids, mask
-
- def main():
-     np.random.seed(SEED)
-     print(f"🚀 Starting Prepare-Script for SmaLLMPro (350M Instruct)...")
-     print(f"📚 Tokenizer: {TOKENIZER_NAME}")
-
-     os.makedirs(OUTPUT_DIR, exist_ok=True)
-
-     print("📥 Loading 'yahma/alpaca-cleaned' (Chat-Instructions)...")
-     alpaca = load_dataset("yahma/alpaca-cleaned", split='train')
-
-     print(f"📥 Loading 'HuggingFaceFW/fineweb-edu' (Sample-10BT) for {FINEWEB_SAMPLES} Samples...")
-     fineweb = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split='train', streaming=True)
-
-     all_tokens = []
-     all_masks = []
-
-     print("⚙️ Processing Alpaca...")
-     for ex in tqdm(alpaca, desc="Alpaca"):
-         ids, mask = format_prompt_with_mask(ex['instruction'], ex['input'], ex['output'])
-         all_tokens.extend(ids)
-         all_masks.extend(mask)
-
-     alpaca_len = len(all_tokens)
-     print(f" -> Alpaca Tokens: {alpaca_len:,}")
-
-     print("⚙️ Processing FineWeb (Anti-Forgetting)...")
-     fw_iter = iter(fineweb)
-     fw_count = 0
-     fw_tokens_count = 0
-
-     for _ in tqdm(range(FINEWEB_SAMPLES), desc="FineWeb"):
-         try:
-             ex = next(fw_iter)
-             text = ex['text'] + EOS_TOKEN
-             ids = enc.encode(text, allowed_special={EOS_TOKEN})
-
-             all_tokens.extend(ids)
-             all_masks.extend([1] * len(ids))
-
-             fw_tokens_count += len(ids)
-             fw_count += 1
-         except StopIteration:
-             break
-
-     print(f" -> FineWeb Tokens: {fw_tokens_count:,} (from {fw_count} documents)")
-
-     total_tokens = len(all_tokens)
-     print(f"\n💾 Saving {total_tokens:,} Tokens in '{OUTPUT_DIR}'...")
-
-     token_arr = np.array(all_tokens, dtype=np.uint16)
-     token_arr.tofile(os.path.join(OUTPUT_DIR, "train.bin"))
-
-     mask_arr = np.array(all_masks, dtype=np.uint8)
-     mask_arr.tofile(os.path.join(OUTPUT_DIR, "train_mask.bin"))
-
-     print("\n🔍 --- SANITY CHECK ---")
-     print("Decoding the first 100 tokens to check that everything is okay.")
-     print("Green (TRAIN) = what the model learns. Grey (IGNORE) = what the model only reads.")
-
-     check_len = 100
-     sample_ids = all_tokens[:check_len]
-     sample_mask = all_masks[:check_len]
-
-     decoded_parts = []
-     for t_id, m_val in zip(sample_ids, sample_mask):
-         token_str = enc.decode([t_id])
-         if m_val == 1:
-             decoded_parts.append(f"\033[92m{token_str}\033[0m")
-         else:
-             decoded_parts.append(f"\033[90m{token_str}\033[0m")
-
-     print("".join(decoded_parts))
-     print("\n(Legend: \033[90mGrey=Prompt/Ignored\033[0m, \033[92mGreen=Response/Learned\033[0m)")
-
-     if len(token_arr) != len(mask_arr):
-         print("\n❌ Warning: Token and mask arrays have different lengths! Something has gone wrong!")
-     else:
-         print("\n✅ Everything seems to be fine. The arrays are synchronized. You can now start the training.")
-
- if __name__ == "__main__":
-     main()
+
+ out_dir = '/home/user/350m_SmaLLMPro_Final'
+ init_from = '/home/user/350m_fineweb'
+ dataset = 'alpaca_cleaned_mixed'
+
+ batch_size = 4
+ gradient_accumulation_steps = 32
+ block_size = 1024
+ learning_rate = 2e-5
+ max_iters = 1500
+ weight_decay = 0.1
+ dropout = 0.1
+ warmup_iters = 0
+ min_lr = 3e-6
+ beta1, beta2 = 0.9, 0.95
+ device = 'cuda'
+ dtype = 'bfloat16'
+ compile = True
+ save_interval = 500
+
+ os.makedirs(out_dir, exist_ok=True)
+ torch.manual_seed(1337)
+ device_type = 'cuda' if 'cuda' in device else 'cpu'
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+ ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+
+ data_dir = os.path.join('data', dataset)
+ train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
+ train_mask = np.memmap(os.path.join(data_dir, 'train_mask.bin'), dtype=np.uint8, mode='r')
+
+ def get_batch():
+     # Sample random windows from the packed token stream
+     ix = torch.randint(len(train_data) - block_size, (batch_size,))
+     x = torch.stack([torch.from_numpy((train_data[i:i+block_size]).astype(np.int64)) for i in ix])
+     y = torch.stack([torch.from_numpy((train_data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
+     m = torch.stack([torch.from_numpy((train_mask[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
+
+     # Prompt positions (mask == 0) are excluded from the loss
+     y[m == 0] = -100
+
+     x, y = x.to(device), y.to(device)
+     return x, y
+
+ print(f"📥 Loading Pretraining-Checkpoint from {init_from}...")
+ ckpt_files = sorted([f for f in os.listdir(init_from) if f.endswith('.pt')])
+ if not ckpt_files:
+     raise FileNotFoundError("No checkpoint found in init_from directory!")
+
+ ckpt_path = os.path.join(init_from, ckpt_files[-1])
+ checkpoint = torch.load(ckpt_path, map_location=device)
+ gptconf = GPTConfig(**checkpoint['model_args'])
+ model = GPT(gptconf)
+ state_dict = checkpoint['model']
+
+ # Strip the prefix torch.compile adds to state-dict keys
+ unwanted_prefix = '_orig_mod.'
+ for k, v in list(state_dict.items()):
+     if k.startswith(unwanted_prefix):
+         state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+
+ model.load_state_dict(state_dict)
+ model.to(device)
+
+ if compile:
+     print("🚀 Compiling Model...")
+     model = torch.compile(model)
+
+ optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
+ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
+
+ def get_lr(it):
+     # Linear warmup, then cosine decay from learning_rate to min_lr
+     if it < warmup_iters: return learning_rate * it / warmup_iters
+     if it > max_iters: return min_lr
+     decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
+     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+     return min_lr + coeff * (learning_rate - min_lr)
+
+ print(f"🛠️ Starting Finetuning...")
+ model.train()
+ t0 = time.time()
+
+ for iter_num in range(max_iters + 1):
+     lr = get_lr(iter_num)
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+     for micro_step in range(gradient_accumulation_steps):
+         X, Y = get_batch()
+         with ctx:
+             logits, loss = model(X, Y)
+             loss = loss / gradient_accumulation_steps
+         scaler.scale(loss).backward()
+
+     scaler.unscale_(optimizer)
+     torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+     scaler.step(optimizer)
+     scaler.update()
+     optimizer.zero_grad(set_to_none=True)
+
+     if iter_num % 10 == 0:
+         dt = time.time() - t0
+         print(f"Iter {iter_num}: Loss {loss.item()*gradient_accumulation_steps:.4f}, Time {dt*1000:.2f}ms, LR {lr:.2e}")
+         t0 = time.time()
+
+     if iter_num > 0 and iter_num % save_interval == 0:
+         checkpoint_name = f'SmaLLMPro_iter_{iter_num}.pt'
+         save_path = os.path.join(out_dir, checkpoint_name)
+         print(f"💾 Saving checkpoint: {checkpoint_name}")
+         raw_model = model._orig_mod if compile else model
+         checkpoint_data = {
+             'model': raw_model.state_dict(),
+             'model_args': checkpoint['model_args'],
+             'iter_num': iter_num,
+             'lr': lr,
+         }
+         torch.save(checkpoint_data, save_path)
+
+ print(f"💾 Finetuning done. Saving SmaLLMPro...")
+ final_checkpoint = {
+     'model': model.state_dict() if not compile else model._orig_mod.state_dict(),
+     'model_args': checkpoint['model_args'],
+     'config': checkpoint.get('config', {}),
+ }
+ torch.save(final_checkpoint, os.path.join(out_dir, 'SmaLLMPro_Final.pt'))
+ print("✅ SmaLLMPro saved successfully!")
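One caveat worth flagging in `get_batch`: setting masked targets to `-100` only silences prompt tokens if the loss inside `model.py` uses `ignore_index=-100` (the default of PyTorch's `F.cross_entropy`; stock nanoGPT passes `ignore_index=-1` instead, so this is an assumption to verify against this repo's `model.py`). A minimal standalone sketch of the mechanism:

import torch
import torch.nn.functional as F

vocab_size = 50304
logits = torch.randn(6, vocab_size)           # 6 positions, shape (T, vocab)
targets = torch.randint(0, vocab_size, (6,))  # next-token targets

masked = targets.clone()
masked[:3] = -100                             # pretend the first 3 positions are prompt

full_loss = F.cross_entropy(logits, targets)  # averages over all 6 positions
masked_loss = F.cross_entropy(logits, masked) # ignore_index=-100 by default:
                                              # averages over the last 3 positions only
print(full_loss.item(), masked_loss.item())

If the model's loss ignores `-1` instead, the `-100` targets would be treated as out-of-range class indices and raise an error rather than being skipped.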