Spaces:
Sleeping
Sleeping
| """ | |
| Train Tuned Lens probes — per-layer affine corrections that minimise | |
| KL divergence between an intermediate layer's corrected predictions and | |
| the model's final-layer predictions. | |
| Usage: | |
| python -m scripts.train_tuned_lens \ | |
| --model-id codegen-350m \ | |
| --corpus-file calibration_data.txt \ | |
| --output-dir ./tuned_lens_weights/ \ | |
| --max-samples 2000 --epochs 5 | |
| Each probe is a simple affine map A_l(x) = x @ W_l^T + b_l | |
| initialised to identity + zero so that the untrained probe reproduces | |
| the raw logit lens exactly. | |
| """ | |
| import argparse | |
| import hashlib | |
| import json | |
| import logging | |
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Module-level logger used by the corpus loader, trainer, checkpoint writer and CLI.
logger = logging.getLogger(name=__name__)
| # --------------------------------------------------------------------------- | |
| # AffineProbe | |
| # --------------------------------------------------------------------------- | |
class AffineProbe(nn.Module):
    """Per-layer affine correction ``A(x) = x @ W^T + b``, initialised to identity.

    With ``W = I`` and ``b = 0`` the untrained probe passes its input through
    unchanged, so it reproduces the raw logit lens before any training.

    The parameters are created in torch's default dtype (float32 unless changed
    globally). ``forward`` computes in the parameter dtype and returns the result
    in the input's dtype. This lets the probe consume half-precision hidden
    states and feed half-precision model modules (final norm, lm_head) directly.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.weight = nn.Parameter(torch.eye(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map over the last dimension of *x* (size ``d_model``)."""
        # torch matmul does not promote mixed dtypes (e.g. float16 input vs
        # float32 weight raises), so run the map in the parameter dtype and
        # hand the result back in the caller's dtype.
        out = x.to(self.weight.dtype) @ self.weight.T + self.bias
        return out.to(x.dtype)
| # --------------------------------------------------------------------------- | |
| # Architecture detection — mirrors model_service.py | |
| # --------------------------------------------------------------------------- | |
def get_final_ln_and_lm_head(model):
    """Find the modules that turn a residual-stream vector into logits.

    Two attribute layouts are recognised, checked in this order:

    * ``model.model.norm`` — Mistral / LLaMA-style.
    * ``model.transformer.ln_f`` — GPT-style. NOTE(review): HF CodeGen also
      exposes ``.transformer.ln_f``, so it is presumably matched by this
      layout rather than the first one — confirm.

    Returns:
        ``(final_layer_norm, lm_head)``; ``lm_head`` is always ``model.lm_head``.

    Raises:
        RuntimeError: if neither layout is present on *model*.
    """
    layouts = (
        ("model", "norm"),        # Mistral / LLaMA-style
        ("transformer", "ln_f"),  # GPT-style
    )
    for trunk_attr, norm_attr in layouts:
        trunk = getattr(model, trunk_attr, None)
        if trunk is not None and hasattr(trunk, norm_attr):
            return getattr(trunk, norm_attr), model.lm_head
    raise RuntimeError(
        "Cannot detect final layer norm — model architecture not recognised. "
        "Supported: Mistral/LLaMA (.model.norm), GPT (.transformer.ln_f)"
    )
| # --------------------------------------------------------------------------- | |
| # Model hash — ties checkpoint to exact model weights | |
| # --------------------------------------------------------------------------- | |
def compute_model_hash(model, n_tensors: int = 20) -> str:
    """Return a SHA-256 hex digest of the first *n_tensors* parameter tensors.

    Each tensor's raw bytes (C order, native byte order) are fed to the hash in
    ``model.named_parameters()`` order. This is a cheap fingerprint, not a hash
    of the full model:

    * parameters past the first *n_tensors* are not covered;
    * tensor shapes are not hashed, only bytes;
    * the same weights loaded in a different dtype produce a different digest.
    """
    digest = hashlib.sha256()
    for i, (_, param) in enumerate(model.named_parameters()):
        if i >= n_tensors:
            break
        # Reinterpret the tensor as bytes with torch instead of going through
        # ``.numpy()`` directly. NumPy has no bfloat16, so ``.numpy()`` raises
        # for bf16 models. For dtypes NumPy does support, the byte stream is
        # identical, so digests of existing float16/float32 checkpoints are
        # unchanged. ``.contiguous()`` guarantees unit stride for ``view``.
        raw = param.detach().cpu().contiguous().reshape(-1).view(torch.uint8)
        digest.update(raw.numpy().tobytes())
    return digest.hexdigest()
| # --------------------------------------------------------------------------- | |
| # Corpus loader | |
| # --------------------------------------------------------------------------- | |
def load_corpus(path: str, max_samples: int, max_seq_len: int, tokenizer,
                min_tokens: int = 8) -> list:
    """Load a plain-text corpus and tokenize it into training samples.

    The file is split into paragraphs at blank (whitespace-only) lines. Each
    paragraph becomes one sample, with its lines re-joined by newlines, so a
    file with no blank lines is read as a single sample. Leading blank lines
    and runs of consecutive blank lines are ignored.

    Args:
        path: UTF-8 text file to read.
        max_samples: maximum number of paragraphs to collect. Reading stops once
            this many are gathered. Paragraphs later dropped for being too short
            still count toward the limit.
        max_seq_len: token ids per sample are truncated to this length.
        tokenizer: object with an HF-style
            ``encode(text, add_special_tokens=..., truncation=..., max_length=...)``
            method returning a list of ids.
        min_tokens: samples with fewer token ids than this are dropped.

    Returns:
        list of 1-D ``torch.long`` tensors, one per kept sample.
    """
    texts = []
    buf = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if line.strip():
                buf.append(line)
                continue
            # Blank line: it closes the current paragraph, if any. With an
            # empty buffer it is skipped, so that it does not leak an empty
            # line (or an empty pseudo-sample) into the next paragraph.
            if buf:
                texts.append("\n".join(buf))
                buf = []
                if len(texts) >= max_samples:
                    break
    # Final paragraph when the file does not end with a blank line.
    if buf and len(texts) < max_samples:
        texts.append("\n".join(buf))

    samples = []
    for text in texts[:max_samples]:
        ids = tokenizer.encode(text, add_special_tokens=False, truncation=True,
                               max_length=max_seq_len)
        if len(ids) >= min_tokens:  # skip very short sequences
            samples.append(torch.tensor(ids, dtype=torch.long))
    logging.getLogger(__name__).info(
        f"Loaded {len(samples)} samples from {path} (max_seq_len={max_seq_len})"
    )
    return samples
| # --------------------------------------------------------------------------- | |
| # Training | |
| # --------------------------------------------------------------------------- | |
def _probe_loss(probe, h, final_ln, lm_head, ref_log_probs, identity, l2_weight):
    """Return KL(ref || probe) plus the L2-to-identity penalty for one probe."""
    # Run the probe in its own parameter dtype and hand the result back in the
    # hidden-state (model) dtype. torch matmul does not mix float16/float32,
    # and final_ln / lm_head expect the model's dtype.
    corrected = probe(h.to(probe.weight.dtype)).to(h.dtype)
    probe_logits = lm_head(final_ln(corrected))
    # float32 log-softmax keeps the KL numerically stable in half-precision runs.
    probe_log_probs = F.log_softmax(probe_logits.float(), dim=-1)
    # Inputs are (1, seq, vocab): "batchmean" divides by dim 0, which is 1,
    # so this is the KL summed over all positions of the sequence.
    kl = F.kl_div(probe_log_probs, ref_log_probs, reduction="batchmean", log_target=True)
    # L2 regularisation toward identity: ||W - I||^2 + ||b||^2
    l2_reg = ((probe.weight - identity) ** 2).sum() + (probe.bias ** 2).sum()
    return kl + l2_weight * l2_reg


def train_tuned_lens(
    model,
    tokenizer,
    samples: list,
    device: torch.device,
    lr: float = 1e-3,
    l2_weight: float = 1e-4,
    epochs: int = 5,
):
    """Train one AffineProbe per layer, streaming hidden states (no disk storage).

    Hidden states are recomputed for every sample in every epoch. For layer
    ``l`` the probe is applied to ``hidden_states[l + 1]``, then ``final_ln``
    and ``lm_head``. It is trained to minimise KL(ref || probe) plus
    ``l2_weight * (||W - I||^2 + ||b||^2)``, where ``ref`` is the softmax of
    the model's own output logits.

    The model's parameters are frozen (``requires_grad=False``) while training,
    so backprop only produces gradients for the probes. The flags that were
    set are restored before returning, even on error.

    Args:
        model: causal LM exposing ``config`` and the attributes looked up by
            ``get_final_ln_and_lm_head``; called with ``output_hidden_states=True``.
        tokenizer: unused; kept for interface compatibility.
        samples: list of 1-D token-id tensors.
        device: device the inputs and probes are placed on.
        lr: AdamW learning rate. Weight decay is 0; the explicit L2 term pulls
            weights toward identity instead of zero.
        l2_weight: weight of the L2-to-identity penalty.
        epochs: passes over *samples*, in the given order (no shuffling).

    Returns:
        dict mapping layer index (0 .. n_layers-1) to its trained AffineProbe.
    """
    final_ln, lm_head = get_final_ln_and_lm_head(model)
    config = model.config
    d_model = getattr(config, "hidden_size", None) or getattr(config, "n_embd")
    n_layers = getattr(config, "num_hidden_layers", None) or getattr(config, "n_layer")

    # Create probes + optimizers
    probes = {}
    optimizers = {}
    for l in range(n_layers):
        probe = AffineProbe(d_model).to(device)
        probes[l] = probe
        optimizers[l] = torch.optim.AdamW(probe.parameters(), lr=lr, weight_decay=0.0)
    # Target of the L2 penalty, built once. torch.eye's default dtype matches
    # the dtype AffineProbe creates its weight in.
    identity = torch.eye(d_model, device=device)

    logger.info(f"Training {n_layers} probes (d_model={d_model}, {len(samples)} samples, {epochs} epochs)")

    # Without this, every backward pass would also accumulate (never-zeroed)
    # gradients into lm_head / final_ln weights.
    frozen = [p for p in model.parameters() if p.requires_grad]
    for p in frozen:
        p.requires_grad_(False)
    try:
        for epoch in range(epochs):
            epoch_losses = {l: 0.0 for l in range(n_layers)}
            epoch_count = 0
            for si, sample_ids in enumerate(samples):
                input_ids = sample_ids.unsqueeze(0).to(device)
                with torch.no_grad():
                    outputs = model(input_ids, output_hidden_states=True)
                # Tuple of (n_layers + 1) tensors: [0] = embeddings,
                # [l + 1] = output of layer l.
                hidden_states = outputs.hidden_states
                # Target: the model's actual output distribution. We use
                # outputs.logits rather than lm_head(final_ln(hidden_states[-1]))
                # because in HF implementations the last hidden state is
                # typically already post-final-norm, and normalising it again
                # distorts the target. NOTE(review): confirm per architecture.
                ref_log_probs = F.log_softmax(outputs.logits.float(), dim=-1)

                # Train each layer's probe independently
                for l in range(n_layers):
                    # NOTE(review): for l = n_layers - 1 this is hidden_states[-1],
                    # which may already be normalised (see above), so final_ln is
                    # applied on top of it. Layer indexing is kept as-is for
                    # checkpoint compatibility — confirm this is intended.
                    loss = _probe_loss(
                        probes[l], hidden_states[l + 1], final_ln, lm_head,
                        ref_log_probs, identity, l2_weight,
                    )
                    optimizer = optimizers[l]
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    epoch_losses[l] += loss.item()
                epoch_count += 1

                # Free memory
                del outputs, hidden_states, ref_log_probs

                if (si + 1) % 100 == 0:
                    avg_loss = sum(epoch_losses.values()) / (n_layers * epoch_count)
                    logger.info(f"  Epoch {epoch+1}, sample {si+1}/{len(samples)}, avg loss: {avg_loss:.4f}")

            avg_epoch_loss = sum(epoch_losses.values()) / (n_layers * max(epoch_count, 1))
            logger.info(f"Epoch {epoch+1}/{epochs} complete — avg loss: {avg_epoch_loss:.4f}")
    finally:
        for p in frozen:
            p.requires_grad_(True)

    return probes
| # --------------------------------------------------------------------------- | |
| # Checkpoint saving | |
| # --------------------------------------------------------------------------- | |
def save_checkpoint(probes: dict, model, model_id: str, output_dir: str,
                    training_config: dict):
    """Write the trained probes and a ``metadata.json`` sidecar to disk.

    Files go under ``<output_dir>/<model_id>/``, which is created if needed:

    * ``tuned_lens_<first 16 hex chars of the model hash>.pt`` — a flat state
      dict with keys ``layer_{i}.weight`` / ``layer_{i}.bias`` (CPU tensors).
    * ``metadata.json`` — model id, full model hash, layer count, hidden size,
      *training_config* and a UTC timestamp. Overwritten on every call.

    Returns:
        Path of the written ``.pt`` checkpoint.
    """
    digest = compute_model_hash(model)
    cfg = model.config
    hidden = getattr(cfg, "hidden_size", None) or getattr(cfg, "n_embd")
    depth = getattr(cfg, "num_hidden_layers", None) or getattr(cfg, "n_layer")

    target_dir = Path(output_dir) / model_id
    target_dir.mkdir(parents=True, exist_ok=True)

    # One flat dict: per layer, the weight entry followed by the bias entry.
    tensors = {
        f"layer_{idx}.{part}": getattr(probe, part).data.cpu()
        for idx, probe in probes.items()
        for part in ("weight", "bias")
    }
    ckpt_path = target_dir / f"tuned_lens_{digest[:16]}.pt"
    torch.save(tensors, ckpt_path)

    meta_path = target_dir / "metadata.json"
    with open(meta_path, "w") as fh:
        json.dump(
            {
                "model_id": model_id,
                "model_hash": digest,
                "n_layers": depth,
                "d_model": hidden,
                "training_config": training_config,
                "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            },
            fh,
            indent=2,
        )

    size_mb = ckpt_path.stat().st_size / 1024 / 1024
    logger.info(f"Saved checkpoint to {ckpt_path} ({size_mb:.1f}MB)")
    logger.info(f"Saved metadata to {meta_path}")
    return ckpt_path
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
def main():
    """Command-line entry point: parse arguments, load the model, train, save.

    Exits with status 1 when the corpus yields no usable samples.
    """
    parser = argparse.ArgumentParser(description="Train tuned lens probes for a model")
    parser.add_argument("--model-id", required=True, help="Model identifier (e.g. codegen-350m)")
    parser.add_argument("--model-name", default=None,
                        help="HuggingFace model name (defaults to model-id)")
    parser.add_argument("--corpus-file", required=True, help="Plain-text calibration corpus")
    parser.add_argument("--output-dir", default="./tuned_lens_weights/",
                        help="Output directory for checkpoints")
    # Numeric hyper-parameters, registered in the same order as before.
    for flag, kind, default in (
        ("--max-samples", int, 2000),
        ("--max-seq-len", int, 512),
        ("--epochs", int, 5),
        ("--lr", float, 1e-3),
        ("--l2-weight", float, 1e-4),
    ):
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument("--device", default=None, help="Device (auto-detected if omitted)")
    parser.add_argument("--dtype", default="float16", choices=["float16", "bfloat16", "float32"])
    args = parser.parse_args()

    # Every --dtype choice is the name of a torch dtype attribute.
    dtype = getattr(torch, args.dtype)

    # Explicit --device wins; otherwise prefer CUDA, then Apple MPS, then CPU.
    if args.device:
        chosen = args.device
    elif torch.cuda.is_available():
        chosen = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    device = torch.device(chosen)
    logger.info(f"Device: {device}, dtype: {dtype}")

    model_name = args.model_name or args.model_id
    logger.info(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
    model.eval()

    samples = load_corpus(args.corpus_file, args.max_samples, args.max_seq_len, tokenizer)
    if not samples:
        logger.error("No valid samples loaded — aborting")
        sys.exit(1)

    training_config = dict(
        lr=args.lr,
        l2_weight=args.l2_weight,
        epochs=args.epochs,
        max_samples=args.max_samples,
        max_seq_len=args.max_seq_len,
        dtype=args.dtype,
        num_samples_used=len(samples),
    )

    probes = train_tuned_lens(
        model, tokenizer, samples, device,
        lr=args.lr, l2_weight=args.l2_weight, epochs=args.epochs,
    )
    save_checkpoint(probes, model, args.model_id, args.output_dir, training_config)
    logger.info("Done.")


if __name__ == "__main__":
    main()