"""
Tuned Lens Runtime — load and apply per-layer affine probes for improved
intermediate-layer predictions.

Each probe applies a learned linear correction A_l(x) = x @ W_l^T + b_l
(initialised to identity + zero during training) that is trained to minimise
the KL divergence between the corrected layer's predictions and the model's
final-layer predictions.

See scripts/train_tuned_lens.py for the training pipeline.
"""

import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

# Root directory holding one sub-directory of probe weights per model id.
TUNED_LENS_DIR = os.environ.get("TUNED_LENS_DIR", "./tuned_lens_weights")

# Probe tensors are stored under "layer_<N>.weight" / "layer_<N>.bias" keys.
_LAYER_PREFIX = "layer_"


class TunedLensRuntime:
    """Load, cache, and apply per-layer affine probes at inference time."""

    def __init__(self):
        # layer index -> (weight, bias), already on the device/dtype given to load()
        self._probes: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        # Parsed metadata.json ({} when the file is absent); None when nothing is loaded.
        self._metadata: Optional[dict] = None
        self._available = False

    @property
    def available(self) -> bool:
        """True when the most recent ``load`` call succeeded."""
        return self._available

    def _reset(self) -> None:
        """Discard any loaded probes and metadata and mark the lens unavailable."""
        self._probes = {}
        self._metadata = None
        self._available = False

    def load(self, model_id: str, device: torch.device, dtype: torch.dtype,
             weights_dir: Optional[str] = None) -> bool:
        """Load the tuned lens checkpoint for *model_id*.

        Looks in ``<weights_dir or TUNED_LENS_DIR>/<model_id>/`` for a
        ``tuned_lens_*.pt`` checkpoint (the lexicographically first match is
        used) and an optional ``metadata.json``. Probe tensors are moved to
        *device* and cast to *dtype*.

        Any previously loaded probes are discarded first. After a failed call,
        ``available`` is therefore False and ``apply`` is the identity.

        Returns True if at least one layer probe was loaded, False otherwise.
        Failure is non-fatal — the system falls back to raw logit lens.
        """
        # Clear state up front so a failed load can never leave another
        # model's probes active (they may not even match this model's width).
        self._reset()

        base_dir = Path(weights_dir or TUNED_LENS_DIR)
        model_dir = base_dir / model_id
        if not model_dir.exists():
            logger.info("Tuned lens: no weights directory for %s at %s", model_id, model_dir)
            return False

        # Deterministic choice: the first checkpoint in sorted order.
        pt_files = sorted(model_dir.glob("tuned_lens_*.pt"))
        if not pt_files:
            logger.info("Tuned lens: no .pt checkpoint found in %s", model_dir)
            return False
        checkpoint_path = pt_files[0]
        if len(pt_files) > 1:
            logger.info("Tuned lens: %d checkpoints found in %s; using %s",
                        len(pt_files), model_dir, checkpoint_path.name)
        metadata_path = model_dir / "metadata.json"

        try:
            metadata = self._load_metadata(metadata_path)
            # weights_only=True restricts unpickling to tensors/primitive containers.
            state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
            probes = self._parse_probes(state_dict, device, dtype)
        except Exception as e:
            # Deliberate best-effort: any load problem degrades to the raw logit lens.
            logger.warning("Tuned lens: failed to load checkpoint — %s", e)
            self._reset()
            return False

        if not probes:
            logger.warning("Tuned lens: checkpoint loaded but no layer probes found")
            return False

        self._metadata = metadata
        self._probes = probes
        self._available = True
        logger.info("Tuned lens: loaded %d layer probes from %s (device=%s, dtype=%s)",
                    len(probes), checkpoint_path, device, dtype)
        return True

    @staticmethod
    def _load_metadata(metadata_path: Path) -> dict:
        """Read *metadata_path* as a JSON object; return {} if it is missing or not an object."""
        if not metadata_path.exists():
            return {}
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        if not isinstance(metadata, dict):
            logger.warning("Tuned lens: metadata.json is not a JSON object; ignoring it")
            return {}
        logger.info("Tuned lens: metadata loaded — %s layers, d_model=%s",
                    metadata.get("n_layers"), metadata.get("d_model"))
        return metadata

    @staticmethod
    def _parse_probes(state_dict, device: torch.device,
                      dtype: torch.dtype) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
        """Collect (weight, bias) pairs from ``layer_<N>.weight`` / ``layer_<N>.bias`` keys.

        Keys with a non-integer layer index, and layers missing either tensor,
        are skipped with a log message. They do not abort the whole load.
        """
        layer_indices = set()
        for key in state_dict:
            parts = key.split(".")
            if len(parts) != 2 or not parts[0].startswith(_LAYER_PREFIX):
                continue
            try:
                layer_indices.add(int(parts[0][len(_LAYER_PREFIX):]))
            except ValueError:
                logger.warning("Tuned lens: ignoring key with non-integer layer index: %s", key)

        probes: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        for idx in sorted(layer_indices):
            w_key = f"layer_{idx}.weight"
            b_key = f"layer_{idx}.bias"
            if w_key not in state_dict or b_key not in state_dict:
                logger.warning("Tuned lens: layer %d is missing its weight or bias; skipping", idx)
                continue
            weight = state_dict[w_key].to(device=device, dtype=dtype)
            bias = state_dict[b_key].to(device=device, dtype=dtype)
            probes[idx] = (weight, bias)
        return probes

    def apply(self, layer_idx: int, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply the affine probe for *layer_idx*: ``hidden_state @ W^T + b``.

        If no probe exists for this layer, *hidden_state* is returned unchanged
        (identity fallback).
        NOTE(review): hidden_state presumably must be on the device/dtype passed
        to ``load``. Confirm at call sites.
        """
        probe = self._probes.get(layer_idx)
        if probe is None:
            return hidden_state
        weight, bias = probe
        return hidden_state @ weight.T + bias

    def get_info(self) -> dict:
        """Return metadata dict for health/debug endpoints."""
        return {
            "available": self._available,
            "num_probes": len(self._probes),
            "layer_indices": sorted(self._probes.keys()),
            "metadata": self._metadata or {},
        }


# Global singleton
tuned_lens_runtime = TunedLensRuntime()