import os
import time
import math
import pickle
import random
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from dataclasses import dataclass

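# This script trains a small GPT-style language model on TinyStories with configurable
# positional encodings (RoPE, ALiBi, or the APE distance-dependent temperature/bias),
# then measures teacher-forced perplexity and attention entropy as the prompt length
# grows far beyond the training context length.
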
def get_alibi_slopes(n_head):
    """Compute the ALiBi slope for each attention head.

    Uses the geometric-sequence construction from the ALiBi paper, with the standard
    fallback (interleaved slopes from the next power of two) when n_head is not a
    power of two.
    """
    def get_slopes_power_of_2(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]

    if math.log2(n_head).is_integer():
        slopes = get_slopes_power_of_2(n_head)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
        slopes = get_slopes_power_of_2(closest_power_of_2)
        extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: n_head - closest_power_of_2]
        slopes.extend(extra_slopes)
    return torch.tensor(slopes, dtype=torch.float32)
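
# Illustrative values (not used by the model code): for n_head = 8 the slopes are the
# geometric sequence 1/2, 1/4, ..., 1/256; for the 6-head model configured below they are
# [2^-2, 2^-4, 2^-6, 2^-8] followed by the interleaved extras [2^-1, 2^-3].
# assert torch.allclose(get_alibi_slopes(8), torch.tensor([2.0 ** -(i + 1) for i in range(8)]))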


class LayerNorm(nn.Module):
    """LayerNorm with an optional bias term."""

    def __init__(self, ndim, bias: bool):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # query, key, value projections for all heads in a single linear layer
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.use_rope = config.use_rope
        self.rope_base = config.rope_base

        # APE: distance-dependent temperature plus a learned log-distance bias (see forward).
        self.use_ape = getattr(config, 'use_ape', False)
        # ALiBi: fixed per-head linear distance penalty on the attention scores.
        self.use_alibi = getattr(config, 'use_alibi', False)
        if self.use_alibi and self.use_ape:
            raise ValueError("Cannot use both ALiBi and APE simultaneously.")

        if self.use_ape:
            self.beta = nn.Parameter(torch.tensor(1.0))

        # Flash attention (scaled_dot_product_attention) can only be used when the scores
        # need no custom additive bias; with APE or ALiBi we fall back to the explicit
        # masked-softmax path and therefore need the cached causal mask.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if (not self.flash) or self.use_ape or self.use_alibi:
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                 .view(1, 1, config.block_size, config.block_size))
    def forward(self, x, return_attn_entropy=False, aggregate_heads=False):
        """
        Args:
            x: Input tensor [B, T, C]
            return_attn_entropy (bool): If True, also return the attention entropy.
            aggregate_heads (bool): If True, average the entropy across heads.
        Returns:
            y: Output tensor [B, T, C], or (y, entropy) if return_attn_entropy is True.
        """
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        head_dim = C // self.n_head

        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        if self.use_rope:
            # Rotary position embeddings: rotate each (even, odd) channel pair of q and k
            # by an angle proportional to the token position.
            hs = head_dim
            d = hs // 2
            theta = 1.0 / (self.rope_base ** (2 * torch.arange(0, d, dtype=x.dtype, device=x.device) / hs))
            t_pos = torch.arange(T, device=x.device, dtype=x.dtype)
            freqs = torch.outer(t_pos, theta)
            freqs_cos = torch.cos(freqs).unsqueeze(0).unsqueeze(0)
            freqs_sin = torch.sin(freqs).unsqueeze(0).unsqueeze(0)

            def apply_rope(tensor, cos, sin):
                tensor = tensor.reshape(*tensor.shape[:-1], -1, 2)
                x0 = tensor[..., 0]
                x1 = tensor[..., 1]
                x0_rot = x0 * cos - x1 * sin
                x1_rot = x0 * sin + x1 * cos
                return torch.stack([x0_rot, x1_rot], dim=-1).flatten(start_dim=-2)

            q = apply_rope(q, freqs_cos, freqs_sin)
            k = apply_rope(k, freqs_cos, freqs_sin)

        # Explicit attention scores are needed whenever a positional bias must be added,
        # the fused kernel is unavailable, or the attention entropy is requested.
        need_scores = self.use_alibi or self.use_ape or (not self.flash) or return_attn_entropy
        entropy = None
        if need_scores:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))

        if self.use_alibi:
            # ALiBi: add slope * (j - i) to the score of query i attending to key j.
            # In the causal region (j <= i) this term is <= 0, so attention to more
            # distant past tokens is penalized linearly in distance.
            slopes = get_alibi_slopes(self.n_head).to(x.device)
            rel_positions = torch.arange(T, device=x.device).unsqueeze(0) - torch.arange(T, device=x.device).unsqueeze(1)
            alibi_bias = slopes.view(1, self.n_head, 1, 1) * rel_positions.view(1, 1, T, T)
            att = att + alibi_bias
        elif self.use_ape:
            # APE: rescale the logits by a distance-dependent temperature 1 / (1 + |i - j|)
            # and add a learned log-distance bias -beta * log(1 + |i - j|).
            pos_ids = torch.arange(T, device=x.device)
            rel_dist = pos_ids.unsqueeze(0) - pos_ids.unsqueeze(1)
            abs_rel = rel_dist.abs().float()
            temp_matrix = (1.0 / (1.0 + abs_rel)).unsqueeze(0).unsqueeze(0)
            bias_matrix = (-self.beta * torch.log(1.0 + abs_rel)).unsqueeze(0).unsqueeze(0)
            att = temp_matrix * att + bias_matrix

        if self.flash and not self.use_ape and not self.use_alibi:
            # Fused-kernel path: only valid when the scores carry no custom additive bias.
            if return_attn_entropy:
                # The fused kernel does not expose attention weights, so recompute a
                # causally masked softmax explicitly to measure the entropy.
                causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
                p_att = F.softmax(att.masked_fill(~causal_mask, float('-inf')), dim=-1)
                entropy = -(p_att * torch.log(p_att + 1e-9)).sum(dim=-1)
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True
            )
        else:
            # Explicit path: causal mask, softmax, dropout.
            if T > self.bias.size(-1):
                # Sequences longer than block_size (long-prompt evaluation): build the mask on the fly.
                mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
            else:
                mask = self.bias[:, :, :T, :T]
            att = att.masked_fill(mask == 0, float('-inf'))
            p_att = F.softmax(att, dim=-1)
            if return_attn_entropy:
                entropy = -(p_att * torch.log(p_att + 1e-9)).sum(dim=-1)
            att = self.attn_dropout(p_att)
            y = att @ v
        # Re-assemble all head outputs side by side and project back to the residual stream.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))

        if return_attn_entropy:
            if aggregate_heads:
                entropy = entropy.mean(dim=1)  # (B, n_head, T) -> (B, T)
            return y, entropy
        else:
            return y
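
# Note: the attention entropy returned above is, for each query position i,
# H_i = -sum_j p_ij * log(p_ij + 1e-9) over the causally visible keys j <= i,
# with shape (B, n_head, T), or (B, T) after head averaging (aggregate_heads=True).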


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x, return_attn_entropy=False, aggregate_heads=False):
        if return_attn_entropy:
            attn_output, entropy = self.attn(self.ln_1(x), return_attn_entropy=True, aggregate_heads=aggregate_heads)
            x = x + attn_output
            x = x + self.mlp(self.ln_2(x))
            return x, entropy
        else:
            attn_output = self.attn(self.ln_1(x), return_attn_entropy=False)
            x = x + attn_output
            x = x + self.mlp(self.ln_2(x))
            return x

@dataclass
class GPTConfig:
    block_size: int = 128       # training context length (also sizes the cached causal mask)
    vocab_size: int = 50304     # GPT-2 vocab size (50257) padded up to a multiple of 64
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    dropout: float = 0.0
    bias: bool = True
    use_rope: bool = True       # rotary position embeddings on q/k
    rope_base: float = 10000.0
    use_ape: bool = False       # distance-dependent temperature + learned log-distance bias
    lambda_temp: float = 0.1
    use_alibi: bool = False     # ALiBi linear distance penalty (mutually exclusive with APE)

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None and config.block_size is not None
        self.config = config

        # RoPE and ALiBi both encode position inside attention, so the learned absolute
        # position table (wpe) is only needed when neither of them is active.
        self.use_rope = config.use_rope and not config.use_alibi
        print(f"Using RoPE in GPT init: {self.use_rope}")
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=None if (self.use_rope or config.use_alibi) else nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying between the token embedding and the output projection.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)
        # Scaled initialization for residual projections, as in GPT-2.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
        print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

    def get_num_params(self, non_embedding=True):
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and (not self.use_rope) and (self.transformer.wpe is not None):
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    def forward(self, idx, targets=None, return_attn_entropy=False, aggregate_heads=False):
        device = idx.device
        b, t = idx.size()
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        tok_emb = self.transformer.wte(idx)
        if self.use_rope or self.config.use_alibi:
            # Position is handled inside attention (RoPE / ALiBi); no absolute embedding added.
            x = self.transformer.drop(tok_emb)
        else:
            pos_emb = self.transformer.wpe(pos) if self.transformer.wpe is not None else 0
            x = self.transformer.drop(tok_emb + pos_emb)
        attn_entropies = []
        for block in self.transformer.h:
            if return_attn_entropy:
                x, entropy = block(x, return_attn_entropy=True, aggregate_heads=aggregate_heads)
                attn_entropies.append(entropy)
            else:
                x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference: only the last position's logits are needed.
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        if return_attn_entropy:
            return logits, loss, attn_entropies
        return logits, loss
    @torch.no_grad()
    def generate_and_compute_perplexity(self, prompt, ground_truth, temperature=1.0, return_attn_entropy=False, aggregate_heads=False):
        """Teacher-forced evaluation: walk along the continuation in `ground_truth`,
        scoring each ground-truth token, and return (sequence, perplexity, avg_entropy),
        where perplexity = exp(mean NLL of the continuation tokens) and avg_entropy is the
        prompt's attention entropy averaged over layers, heads and positions (or None)."""
        if return_attn_entropy:
            _, _, attn_entropies = self(prompt, return_attn_entropy=True, aggregate_heads=aggregate_heads)
            per_layer_avgs = [entropy.mean().item() for entropy in attn_entropies]
            avg_entropy = np.mean(per_layer_avgs)
        else:
            avg_entropy = None
        total_loss = 0.0
        total_tokens = 0
        prompt_length = prompt.size(1)
        num_target_tokens = ground_truth.size(1) - prompt_length
        idx = prompt.clone()
        for i in range(num_target_tokens):
            logits, _ = self(idx)
            logits = logits[:, -1, :] / temperature
            target = ground_truth[:, prompt_length + i]
            loss = F.cross_entropy(logits, target, reduction='sum')
            total_loss += loss.item()
            total_tokens += target.numel()
            # Teacher forcing: append the ground-truth token rather than a sampled one.
            target_token = target.unsqueeze(1)
            idx = torch.cat((idx, target_token), dim=1)
        avg_neg_log_likelihood = total_loss / total_tokens if total_tokens > 0 else float('inf')
        perplexity = math.exp(avg_neg_log_likelihood)
        return idx, perplexity, avg_entropy
    @torch.no_grad()
    def generate_until_end(self, idx, temperature=1.0, top_k=None, max_new_tokens=1000):
        """Sample autoregressively until the GPT-2 end-of-text token or max_new_tokens."""
        for i in range(max_new_tokens):
            idx_cond = idx  # no cropping: RoPE/ALiBi let the model run past block_size
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            if idx_next.item() == 50256:  # GPT-2 <|endoftext|>
                break
        return idx
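
# Minimal smoke test (illustrative sketch, commented out; the sizes below are arbitrary
# and are not the training configuration used later in this script):
#   _conf = GPTConfig(block_size=32, n_layer=2, n_head=2, n_embd=64, vocab_size=128)
#   _model = GPT(_conf)
#   _x = torch.randint(0, 128, (2, 32))
#   _logits, _loss = _model(_x, _x)   # _logits: (2, 32, 128), _loss: scalar cross-entropy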
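
# Data: memmapped .bin files of uint16 token ids (TinyStories train/val splits,
# presumably GPT-2 BPE given the 50256 end-of-text id used above).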
train_data_path = "/data1/home/nitinvetcha/Topics in AI/Streamlined/COLM2025/train_tinystories.bin"
val_data_path = "/data1/home/nitinvetcha/Topics in AI/Streamlined/COLM2025/val_tinystories.bin"


def get_batch(split):
    """Sample a random batch of (input, shifted-target) windows from the token file."""
    data_path = train_data_path if split == 'train' else val_data_path
    data = np.memmap(data_path, dtype=np.uint16, mode='r')
    total_tokens = len(data)
    max_ix = max(1, total_tokens - gptconf.block_size)
    ix = torch.randint(0, max_ix, (batch_size,))
    X = torch.stack([torch.from_numpy(data[i:i + gptconf.block_size].astype(np.int64)) for i in ix])
    Y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + gptconf.block_size].astype(np.int64)) for i in ix])
    return X.to(device), Y.to(device)
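
# Example usage (valid once gptconf, batch_size and device are defined below):
#   X, Y = get_batch('train')   # X, Y: (batch_size, block_size) int64 tensors on `device`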


def evaluate_prompt_perplexity(model, token_file, prompt_length, num_trials, generation_params, device):
    """Average teacher-forced perplexity (and prompt attention entropy) over several
    randomly chosen windows of the token file, for a fixed prompt length."""
    tokens = np.fromfile(token_file, dtype=np.uint16)
    total_tokens = len(tokens)
    perplexities = []
    entropy_trials = []
    max_new_tokens = generation_params.get("max_new_tokens", 50)
    total_length = prompt_length + max_new_tokens
    for trial in range(num_trials):
        start_idx = random.randint(0, total_tokens - total_length)
        sequence_tokens = tokens[start_idx: start_idx + total_length]
        # Cast uint16 -> int64 before building tensors (uint16 is not a universally supported torch dtype).
        prompt_tokens = sequence_tokens[:prompt_length].astype(np.int64)
        ground_truth_tokens = sequence_tokens.astype(np.int64)
        prompt_tensor = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0).to(device)
        ground_truth_tensor = torch.tensor(ground_truth_tokens, dtype=torch.long).unsqueeze(0).to(device)
        _, ppl, trial_entropy = model.generate_and_compute_perplexity(
            prompt_tensor, ground_truth_tensor,
            temperature=generation_params.get("temperature", 1.0),
            return_attn_entropy=True, aggregate_heads=True
        )
        perplexities.append(ppl)
        entropy_trials.append(trial_entropy)
        print(f"Trial {trial + 1}/{num_trials} for prompt length {prompt_length}: Perplexity = {ppl:.2f}, Avg Entropy = {trial_entropy:.4f}")
    avg_ppl = np.mean(perplexities)
    avg_entropy = np.mean(entropy_trials)
    print(f"Prompt Length {prompt_length} - Avg Perplexity: {avg_ppl:.2f}, Avg Attention Entropy: {avg_entropy:.4f}\n")
    return avg_ppl, avg_entropy
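
# Training setup: optimizer hyperparameters, device, random seed, and the model configuration.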
batch_size = 12
max_iters = 25001
save_interval = 5000
learning_rate = 6e-4
weight_decay = 1e-1
grad_clip = 1.0
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(1337)

# Model configuration for this run: 6 layers, block_size 64, RoPE combined with the
# APE distance bias (no ALiBi).
model_args = dict(
    n_layer=6,
    n_head=6,
    n_embd=384,
    block_size=64,
    bias=False,
    use_rope=True,
    use_ape=True,
    use_alibi=False,
    rope_base=10000.0,
    vocab_size=50304,
    dropout=0.0
)
gptconf = GPTConfig(**model_args)
model = GPT(gptconf).to(device)
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
iter_num = 0
start_time = time.time()
training_losses = []
validation_losses = []
save_iters = []

# Name the checkpoint directory after the active positional-encoding flags and block size.
flag_parts = []
if gptconf.use_rope:
    flag_parts.append("rope")
if gptconf.use_ape:
    flag_parts.append("ape")
if gptconf.use_alibi:
    flag_parts.append("alibi")
flag_str = "_".join(flag_parts) if flag_parts else "none"
weight_dir = f"weights_{flag_str}_{gptconf.block_size}"
os.makedirs(weight_dir, exist_ok=True)
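
# Training loop: one AdamW step per iteration, a single-batch validation loss tracked every
# iteration, a progress print every 100 iterations, and a checkpoint every save_interval iterations.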

while iter_num < max_iters:
    # One optimization step on a training batch.
    X_train, Y_train = get_batch('train')
    optimizer.zero_grad()
    logits, loss_train = model(X_train, Y_train)
    loss_train.backward()
    if grad_clip > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    training_losses.append(loss_train.item())

    # Single-batch validation loss for the learning curves.
    model.eval()
    X_val, Y_val = get_batch('val')
    with torch.no_grad():
        logits_val, loss_val = model(X_val, Y_val)
    validation_losses.append(loss_val.item())
    model.train()

    if iter_num % 100 == 0:
        elapsed = time.time() - start_time
        print(f"Iter {iter_num:5d}: train loss = {loss_train.item():.4f}, val loss = {loss_val.item():.4f}, time/iter = {elapsed / (iter_num + 1):.4f}s")

    if iter_num > 0 and iter_num % save_interval == 0:
        save_iters.append(iter_num)
        ckpt = {
            'iter_num': iter_num,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'training_losses': training_losses,
            'validation_losses': validation_losses,
            'save_iters': save_iters,
        }
        ckpt_path = os.path.join(weight_dir, f"ckpt_{iter_num}.pt")
        torch.save(ckpt, ckpt_path)
        print(f"Checkpoint saved to {ckpt_path}")

    iter_num += 1

print("Training complete.")

plt.figure(figsize=(10, 6))
plt.plot(range(len(training_losses)), training_losses, label="Training Loss")
plt.plot(range(len(validation_losses)), validation_losses, label="Validation Loss", alpha=0.7)
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Training and Validation Loss per Iteration")
plt.legend()
plt.grid(True)
plt.show()
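
# Length-generalization evaluation: teacher-forced perplexity and attention entropy as a
# function of prompt length, with prompts far longer than the training block_size of 64.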
token_file = val_data_path
prompt_lengths = [64, 128, 256, 512, 1024, 2048, 4096, 8192]
num_trials = 5
generation_params = {"temperature": 1.0, "max_new_tokens": 50}

model.eval()  # deterministic evaluation (no dropout)
avg_perplexities = []
avg_entropies = []

for pl in prompt_lengths:
    print(f"Evaluating for prompt length: {pl}")
    avg_ppl, avg_entropy = evaluate_prompt_perplexity(model, token_file, pl, num_trials, generation_params, device)
    avg_perplexities.append(avg_ppl)
    avg_entropies.append(avg_entropy)

results = {
    "prompt_lengths": prompt_lengths,
    "avg_perplexities": avg_perplexities,
    "avg_entropies": avg_entropies
}
results_filename = f"results_{flag_str}_{gptconf.block_size}.json"
with open(results_filename, "w") as f:
    json.dump(results, f)
print(f"Results saved to {results_filename}")

plt.figure(figsize=(8, 6))
plt.plot(prompt_lengths, avg_perplexities, marker='o')
plt.xlabel("Prompt Length")
plt.ylabel("Avg Generated Perplexity")
plt.title("Avg Generated Perplexity vs Prompt Length")
plt.grid(True)
plt.xscale('log')
plt.savefig(f"avg_generated_perplexity_{flag_str}_{gptconf.block_size}.png")
plt.show()

plt.figure(figsize=(8, 6))
plt.plot(prompt_lengths, avg_entropies, marker='o', color='red')
plt.xlabel("Prompt Length")
plt.ylabel("Avg Attention Entropy")
plt.title("Avg Attention Entropy vs Prompt Length\n(Averaged over Layers)")
plt.grid(True)
plt.xscale('log')
plt.savefig(f"avg_attention_entropy_{flag_str}_{gptconf.block_size}.png")
plt.show()