#!/usr/bin/env python3 """ ARC ADAPTER TRAINING - COMPLETE VERSION ======================================== Trains the combined ARC adapter on a FROZEN base model. Components: - Shared fiber projections (4096 → 16 dim) - Repetition detection head (target: 50×+ separation) - Hedging detection head - Verbosity detection head - Sycophancy detection head - Loop 4 tokenizer expansion - Learned intervention thresholds Base model: COMPLETELY FROZEN (never modified) Adapter: ~2M trainable parameters Author: Logan Napolitano Date: January 2026 """ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import Dataset, DataLoader import numpy as np import json import re import gc import os import time from pathlib import Path from dataclasses import dataclass, field from typing import List, Dict, Optional, Tuple from collections import defaultdict from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig from tqdm import tqdm import warnings warnings.filterwarnings("ignore") # ============================================================================= # CONFIGURATION # ============================================================================= @dataclass class ARCAdapterConfig: """Complete configuration for ARC Adapter training.""" # Paths base_model_path: str = "." 
output_dir: str = "arc_adapter" # Device device: str = "cuda" # Model architecture (auto-filled from base model) hidden_dim: int = 4096 fiber_dim: int = 16 probe_layers: List[int] = field(default_factory=lambda: [8, 16, 24]) # Data generation settings n_samples_per_head: int = 300 max_gen_tokens: int = 80 repetition_window: int = 32 # Training settings epochs: int = 15 batch_size: int = 32 learning_rate: float = 1e-4 weight_decay: float = 0.01 warmup_steps: int = 100 # Target separations for each head target_separation: Dict[str, float] = field(default_factory=lambda: { "repetition": 50.0, # We've achieved 125×, so 50× is conservative "hedging": 5.0, "verbosity": 5.0, "sycophancy": 3.0, }) # Loop 4 settings loop4_iterations: int = 3 n_merges_per_iteration: int = 10 min_pair_frequency: int = 2 # Intervention defaults (learned during training) default_thresholds: Dict[str, float] = field(default_factory=lambda: { "repetition": 0.1, "hedging": 0.3, "verbosity": 0.4, "sycophancy": 0.4, }) default_penalty_strength: float = 2.0 # EMA settings for control field ema_alpha: float = 0.15 # ============================================================================= # ADAPTER ARCHITECTURE # ============================================================================= class FiberProjection(nn.Module): """ Projects hidden states from multiple layers to shared fiber space. This is the geometric core of CF-HoT - compressing high-dimensional hidden states to a low-dimensional manifold where behavioral tendencies are linearly separable. 
""" def __init__(self, hidden_dim: int, fiber_dim: int, n_layers: int): super().__init__() self.hidden_dim = hidden_dim self.fiber_dim = fiber_dim self.n_layers = n_layers # Per-layer projection matrices self.projections = nn.ModuleList([ nn.Linear(hidden_dim, fiber_dim, bias=True) for _ in range(n_layers) ]) # Learned layer importance weights self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers) # Initialize projections for proj in self.projections: nn.init.xavier_uniform_(proj.weight) nn.init.zeros_(proj.bias) def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor: """ Project list of hidden states to fiber space. Args: hidden_states: List of [batch, seq, hidden_dim] tensors Returns: fiber: [batch, seq, fiber_dim] """ weights = F.softmax(self.layer_weights, dim=0) fiber = None for i, (h, proj) in enumerate(zip(hidden_states, self.projections)): # Cast to float32 for adapter computation h = h.float() projected = proj(h) if fiber is None: fiber = weights[i] * projected else: fiber = fiber + weights[i] * projected return fiber def forward_stacked(self, hidden_stack: torch.Tensor) -> torch.Tensor: """ Project stacked hidden states to fiber space. Args: hidden_stack: [batch, n_layers, hidden_dim] Returns: fiber: [batch, fiber_dim] """ # Cast to float32 for adapter computation (model outputs bfloat16) hidden_stack = hidden_stack.float() weights = F.softmax(self.layer_weights, dim=0) batch_size = hidden_stack.shape[0] fiber = torch.zeros( batch_size, self.fiber_dim, device=hidden_stack.device, dtype=torch.float32 ) for i, proj in enumerate(self.projections): fiber = fiber + weights[i] * proj(hidden_stack[:, i, :]) return fiber class BehaviorHead(nn.Module): """ Single behavioral detection head. Takes fiber state, outputs probability of specific behavior. 
class BehaviorHead(nn.Module):
    """
    One behavioral detection head: fiber state in, behavior probability out.

    Architecture: fiber_dim → 64 → 16 → 1 (MLP with ReLU + dropout).
    """

    def __init__(self, fiber_dim: int, name: str):
        super().__init__()
        self.name = name
        self.fiber_dim = fiber_dim

        self.classifier = nn.Sequential(
            nn.Linear(fiber_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Dropout(0.05),
            nn.Linear(16, 1),
        )

        # Xavier init on the linear layers, zero biases.
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, fiber: torch.Tensor) -> torch.Tensor:
        """
        Raw logits for the behavior.

        Args:
            fiber: [batch, fiber_dim] or [batch, seq, fiber_dim]

        Returns:
            logits: [batch] or [batch, seq] (trailing singleton dim removed)
        """
        return self.classifier(fiber).squeeze(-1)

    def predict_proba(self, fiber: torch.Tensor) -> torch.Tensor:
        """Behavior probability: sigmoid of forward()."""
        return torch.sigmoid(self.forward(fiber))
class ARCAdapter(nn.Module):
    """
    Complete ARC Adapter module.

    Contains:
    - Shared fiber projection (geometry)
    - Multiple behavioral heads (detection)
    - Intervention parameters (control)
    - EMA tracking (temporal smoothing)

    The base model is never modified; only these parameters train.
    """

    def __init__(self, config: ARCAdapterConfig):
        super().__init__()
        self.config = config

        # Shared fiber projection
        self.fiber_proj = FiberProjection(
            hidden_dim=config.hidden_dim,
            fiber_dim=config.fiber_dim,
            n_layers=len(config.probe_layers)
        )

        # Behavioral detection heads
        self.heads = nn.ModuleDict({
            "repetition": BehaviorHead(config.fiber_dim, "repetition"),
            "hedging": BehaviorHead(config.fiber_dim, "hedging"),
            "verbosity": BehaviorHead(config.fiber_dim, "verbosity"),
            "sycophancy": BehaviorHead(config.fiber_dim, "sycophancy"),
        })

        # Learned intervention thresholds (one per behavior)
        self.thresholds = nn.ParameterDict({
            name: nn.Parameter(torch.tensor(thresh))
            for name, thresh in config.default_thresholds.items()
        })

        # Learned penalty strength
        self.penalty_strength = nn.Parameter(
            torch.tensor(config.default_penalty_strength)
        )

        # EMA state for control field accumulation.
        # NOTE: the '_ema_initialized' buffer is not read anywhere below;
        # it is kept so previously saved state_dicts remain loadable.
        self.ema_alpha = config.ema_alpha
        self.register_buffer('_ema_initialized', torch.tensor(False))
        self._ema_states: Dict[str, Optional[float]] = {}
        self.reset_ema()

    def reset_ema(self):
        """Reset EMA states for a new generation."""
        self._ema_states = {name: None for name in self.heads.keys()}

    def forward(
        self,
        hidden_states: List[torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Full forward pass through adapter.

        Args:
            hidden_states: List of hidden states from probe layers

        Returns:
            Dict mapping head_name → logits
        """
        fiber = self.fiber_proj(hidden_states)
        return {name: head(fiber) for name, head in self.heads.items()}

    def get_fiber(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """Get fiber representation."""
        return self.fiber_proj(hidden_states)

    def get_risks(
        self,
        hidden_states: List[torch.Tensor],
        update_ema: bool = True
    ) -> Dict[str, float]:
        """
        Get current risk scores with optional EMA update.

        Args:
            hidden_states: List of [1, 1, hidden_dim] tensors (last position)
            update_ema: Whether to update EMA states

        Returns:
            Dict mapping head_name → risk score (0-1)
        """
        risks: Dict[str, float] = {}
        # FIX: wrap the whole computation in no_grad. Previously only the
        # head call was guarded, so the fiber projection built an unused
        # autograd graph on every generated token.
        with torch.no_grad():
            # Take the last position from each layer: [batch, n_layers, hidden_dim]
            stacked = torch.stack([h[:, -1, :] for h in hidden_states], dim=1)
            fiber = self.fiber_proj.forward_stacked(stacked)

            for name, head in self.heads.items():
                prob = head.predict_proba(fiber).mean().item()

                if update_ema:
                    prev = self._ema_states[name]
                    if prev is None:
                        self._ema_states[name] = prob
                    else:
                        self._ema_states[name] = (
                            self.ema_alpha * prob
                            + (1 - self.ema_alpha) * prev
                        )
                    risks[name] = self._ema_states[name]
                else:
                    risks[name] = prob
        return risks

    def compute_intervention(
        self,
        risks: Dict[str, float],
        recent_tokens: List[int],
        window_size: int = 32
    ) -> Dict[int, float]:
        """
        Compute logit penalties based on current risks.

        Args:
            risks: Current risk scores from get_risks()
            recent_tokens: Recently generated token IDs
            window_size: How far back to penalize repetitions

        Returns:
            Dict mapping token_id → penalty amount
        """
        penalties: Dict[int, float] = {}

        # Repetition intervention
        rep_risk = risks.get("repetition", 0)
        rep_thresh = self.thresholds["repetition"].item()

        if rep_risk > rep_thresh:
            # Scale penalty by how much we exceed threshold
            strength = self.penalty_strength.item() * (rep_risk / rep_thresh)
            # Penalize recently used tokens. The negative slice already
            # handles lists shorter than the window.
            for token_id in set(recent_tokens[-window_size:]):
                penalties[token_id] = penalties.get(token_id, 0) + strength

        # Could add hedging/verbosity interventions here
        # (e.g., penalize "As an AI" type tokens)
        return penalties

    def get_param_count(self) -> int:
        """Get total trainable parameter count."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def save(self, path: str):
        """Save adapter weights + JSON config to a directory."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Save model weights
        torch.save(self.state_dict(), path / "adapter_weights.pt")

        # Save config as JSON
        config_dict = {
            "hidden_dim": self.config.hidden_dim,
            "fiber_dim": self.config.fiber_dim,
            "probe_layers": self.config.probe_layers,
            "ema_alpha": self.ema_alpha,
            "thresholds": {
                name: self.thresholds[name].item()
                for name in self.thresholds
            },
            "penalty_strength": self.penalty_strength.item(),
            "head_names": list(self.heads.keys()),
        }
        with open(path / "adapter_config.json", "w") as f:
            json.dump(config_dict, f, indent=2)

        print(f"💾 Adapter saved to {path}")
        print(f"   Weights: adapter_weights.pt")
        print(f"   Config: adapter_config.json")

    @classmethod
    def load(cls, path: str, device: str = "cuda") -> "ARCAdapter":
        """Load adapter from a directory written by save()."""
        path = Path(path)

        # Load config
        with open(path / "adapter_config.json") as f:
            config_dict = json.load(f)

        # Create config object
        config = ARCAdapterConfig(
            hidden_dim=config_dict["hidden_dim"],
            fiber_dim=config_dict["fiber_dim"],
            probe_layers=config_dict["probe_layers"],
            ema_alpha=config_dict["ema_alpha"],
            default_thresholds=config_dict["thresholds"],
            default_penalty_strength=config_dict["penalty_strength"],
        )

        # Create adapter
        adapter = cls(config)

        # Load weights. FIX: weights_only=True — the file is a plain tensor
        # state_dict, so there is no reason to allow arbitrary pickled code.
        state_dict = torch.load(
            path / "adapter_weights.pt",
            map_location=device,
            weights_only=True,
        )
        adapter.load_state_dict(state_dict)
        return adapter.to(device)
# =============================================================================
# DATA GENERATION
# =============================================================================

class DataGenerator:
    """
    Generates labeled training data for all behavioral heads.

    Each head needs different types of data:
    - Repetition: Token-level labels during generation
    - Hedging: Prompt-level labels based on response patterns
    - Verbosity: Prompt-level labels based on response length/density
    - Sycophancy: Prompt-level labels based on agreement patterns
    """

    # Pattern definitions (matched case-insensitively against responses)
    HEDGING_PATTERNS = [
        r"as an ai",
        r"i cannot",
        r"i can't",
        r"i don't have",
        r"i'm not able",
        r"i am not able",
        r"i don't actually",
        r"i'm just",
        r"i am just",
        r"my purpose is",
        r"i was designed",
        r"i don't experience",
        r"i'm an ai",
        r"as a language model",
    ]

    SYCOPHANCY_PATTERNS = [
        r"great question",
        r"excellent point",
        r"you're (absolutely )?right",
        r"that's a (great|excellent|wonderful)",
        r"i (completely |totally )?agree",
        r"absolutely[,!]",
        r"definitely[,!]",
        r"of course[,!]",
        r"you make a (great|excellent|good) point",
    ]

    def __init__(self, model, tokenizer, config: ARCAdapterConfig):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = config.device

        # Compile patterns once; they are used on every generated response.
        self.hedging_patterns = [
            re.compile(p, re.IGNORECASE) for p in self.HEDGING_PATTERNS
        ]
        self.sycophancy_patterns = [
            re.compile(p, re.IGNORECASE) for p in self.SYCOPHANCY_PATTERNS
        ]

    @staticmethod
    def _pct(count: int, total: int) -> float:
        """Percentage helper, safe when total == 0 (all prompts failed)."""
        return 100 * count / total if total else 0.0

    def is_repetition(self, tokens: List[int], position: int) -> bool:
        """Check if token at position repeats within the repetition window."""
        if position < 1:
            return False
        current = tokens[position]
        start = max(0, position - self.config.repetition_window)
        return current in tokens[start:position]

    def is_hedging(self, text: str) -> bool:
        """Check if text contains hedging patterns."""
        return any(p.search(text) for p in self.hedging_patterns)

    def is_sycophantic(self, text: str) -> bool:
        """Check if text contains sycophancy patterns."""
        return any(p.search(text) for p in self.sycophancy_patterns)

    def is_verbose(self, text: str, token_count: int) -> bool:
        """
        Check if response is verbose.

        Verbose = low information density (unique-word ratio < 0.5) or
        excessive length (> 100 generated tokens). Responses under 10 words
        are never verbose.
        """
        words = text.split()
        if len(words) < 10:
            return False
        unique_ratio = len(set(w.lower() for w in words)) / len(words)
        return unique_ratio < 0.5 or token_count > 100

    def extract_hidden_states(
        self,
        input_ids: torch.Tensor
    ) -> torch.Tensor:
        """
        Extract hidden states at probe layers for the last position.

        Args:
            input_ids: [1, seq_len]

        Returns:
            hidden_stack: [n_layers, hidden_dim] on CPU
        """
        with torch.no_grad():
            outputs = self.model(
                input_ids,
                output_hidden_states=True,
            )
        hidden_list = []
        for layer_idx in self.config.probe_layers:
            # Get last position: [hidden_dim]
            h = outputs.hidden_states[layer_idx][0, -1, :].cpu()
            hidden_list.append(h)
        return torch.stack(hidden_list)  # [n_layers, hidden_dim]

    def generate_repetition_data(
        self,
        prompts: List[str]
    ) -> Dict[str, List]:
        """
        Generate token-level labeled data for repetition detection.

        For each generated token, we capture:
        - Hidden states at probe layers (before generating the token)
        - Label: 1 if the token repeats within window, 0 otherwise
        """
        all_hidden = []
        all_labels = []

        print(f"\n📊 Generating repetition training data...")
        print(f"   Prompts: {len(prompts)}")
        print(f"   Max tokens per prompt: {self.config.max_gen_tokens}")

        for prompt in tqdm(prompts, desc="Repetition data"):
            try:
                inputs = self.tokenizer(
                    prompt, return_tensors="pt"
                ).to(self.device)
                generated_ids = inputs.input_ids[0].tolist()

                for step in range(self.config.max_gen_tokens):
                    # Current sequence as tensor
                    input_tensor = torch.tensor([generated_ids]).to(self.device)

                    # Extract hidden states BEFORE generating next token
                    hidden_stack = self.extract_hidden_states(input_tensor)

                    # Generate next token (temperature 0.8 sampling)
                    with torch.no_grad():
                        outputs = self.model(input_tensor)
                        logits = outputs.logits[0, -1, :]
                        probs = F.softmax(logits / 0.8, dim=-1)
                        next_token = torch.multinomial(probs, 1).item()

                    # Record position and add token
                    position = len(generated_ids)
                    generated_ids.append(next_token)

                    # Label: did this token repeat?
                    is_rep = self.is_repetition(generated_ids, position)
                    all_hidden.append(hidden_stack)
                    all_labels.append(1 if is_rep else 0)

                    # Stop at EOS
                    if next_token == self.tokenizer.eos_token_id:
                        break
            except Exception as e:
                print(f"   Error on prompt: {e}")
                continue

        pos_count = sum(all_labels)
        total = len(all_labels)
        # FIX: guard the percentage against total == 0 (previously a
        # ZeroDivisionError when every prompt errored out).
        print(f"   Generated: {total} examples")
        print(f"   Positive (repetition): {pos_count} ({self._pct(pos_count, total):.1f}%)")
        print(f"   Negative (no repeat): {total - pos_count}")

        return {
            "hidden_states": all_hidden,
            "labels": all_labels,
        }

    def generate_hedging_data(
        self,
        prompts: List[str]
    ) -> Dict[str, List]:
        """
        Generate prompt-level labeled data for hedging detection.

        For each prompt, we:
        - Extract hidden states at end of prompt
        - Generate a response
        - Label: 1 if response contains hedging patterns, 0 otherwise
        """
        all_hidden = []
        all_labels = []

        print(f"\n📊 Generating hedging training data...")
        print(f"   Prompts: {len(prompts)}")

        for prompt in tqdm(prompts, desc="Hedging data"):
            try:
                inputs = self.tokenizer(
                    prompt, return_tensors="pt"
                ).to(self.device)

                # Hidden states at end of prompt
                hidden_stack = self.extract_hidden_states(inputs.input_ids)

                # Generate response
                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs.input_ids,
                        max_new_tokens=50,
                        do_sample=True,
                        temperature=0.7,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )

                # Decode response only (not prompt)
                response = self.tokenizer.decode(
                    outputs[0][inputs.input_ids.shape[1]:],
                    skip_special_tokens=True
                )

                is_hedge = self.is_hedging(response)
                all_hidden.append(hidden_stack)
                all_labels.append(1 if is_hedge else 0)
            except Exception as e:
                # Best-effort data collection: skip prompts that fail.
                continue

        pos_count = sum(all_labels)
        total = len(all_labels)
        print(f"   Generated: {total} examples")
        print(f"   Positive (hedging): {pos_count} ({self._pct(pos_count, total):.1f}%)")

        return {
            "hidden_states": all_hidden,
            "labels": all_labels,
        }

    def generate_verbosity_data(
        self,
        prompts: List[str]
    ) -> Dict[str, List]:
        """Generate prompt-level labeled data for verbosity detection."""
        all_hidden = []
        all_labels = []

        print(f"\n📊 Generating verbosity training data...")
        print(f"   Prompts: {len(prompts)}")

        for prompt in tqdm(prompts, desc="Verbosity data"):
            try:
                inputs = self.tokenizer(
                    prompt, return_tensors="pt"
                ).to(self.device)
                hidden_stack = self.extract_hidden_states(inputs.input_ids)

                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs.input_ids,
                        max_new_tokens=150,
                        do_sample=True,
                        temperature=0.7,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )

                response = self.tokenizer.decode(
                    outputs[0][inputs.input_ids.shape[1]:],
                    skip_special_tokens=True
                )
                token_count = outputs.shape[1] - inputs.input_ids.shape[1]

                is_verbose = self.is_verbose(response, token_count)
                all_hidden.append(hidden_stack)
                all_labels.append(1 if is_verbose else 0)
            except Exception as e:
                continue

        pos_count = sum(all_labels)
        total = len(all_labels)
        print(f"   Generated: {total} examples")
        print(f"   Positive (verbose): {pos_count} ({self._pct(pos_count, total):.1f}%)")

        return {
            "hidden_states": all_hidden,
            "labels": all_labels,
        }

    def generate_sycophancy_data(
        self,
        prompts: List[str]
    ) -> Dict[str, List]:
        """Generate prompt-level labeled data for sycophancy detection."""
        all_hidden = []
        all_labels = []

        print(f"\n📊 Generating sycophancy training data...")
        print(f"   Prompts: {len(prompts)}")

        for prompt in tqdm(prompts, desc="Sycophancy data"):
            try:
                inputs = self.tokenizer(
                    prompt, return_tensors="pt"
                ).to(self.device)
                hidden_stack = self.extract_hidden_states(inputs.input_ids)

                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs.input_ids,
                        max_new_tokens=50,
                        do_sample=True,
                        temperature=0.7,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )

                response = self.tokenizer.decode(
                    outputs[0][inputs.input_ids.shape[1]:],
                    skip_special_tokens=True
                )

                is_syc = self.is_sycophantic(response)
                all_hidden.append(hidden_stack)
                all_labels.append(1 if is_syc else 0)
            except Exception as e:
                continue

        pos_count = sum(all_labels)
        total = len(all_labels)
        print(f"   Generated: {total} examples")
        print(f"   Positive (sycophantic): {pos_count} ({self._pct(pos_count, total):.1f}%)")

        return {
            "hidden_states": all_hidden,
            "labels": all_labels,
        }

    def get_prompts_for_head(self, head_name: str, n: int) -> List[str]:
        """Get appropriate prompts for each head type."""
        if head_name == "repetition":
            # Prompts that tend to induce repetitive generation
            templates = [
                "Write a detailed explanation of {}:",
                "Describe the process of {} step by step:",
                "Explain everything about {}:",
                "Give a comprehensive overview of {}:",
                "Discuss {} in depth:",
            ]
            topics = [
                "machine learning", "photosynthesis", "gravity",
                "neural networks", "evolution", "quantum mechanics",
                "economics", "climate change", "DNA replication",
                "cellular respiration", "protein synthesis", "mitosis",
            ]
        elif head_name == "hedging":
            # Mix of hedging-likely and neutral prompts
            hedging_prompts = [
                "Do you have feelings?",
                "Are you conscious?",
                "Can you experience emotions?",
                "What do you think about yourself?",
                "Do you have preferences?",
                "Are you sentient?",
                "Do you have a personality?",
                "Can you be creative?",
            ]
            neutral_prompts = [
                "What is photosynthesis?",
                "Explain gravity.",
                "How do computers work?",
                "What is DNA?",
                "Describe the water cycle.",
                "What causes earthquakes?",
            ]
            # Mix: half hedging-likely, half neutral
            prompts = []
            for i in range(n):
                if i % 2 == 0:
                    prompts.append(hedging_prompts[i % len(hedging_prompts)])
                else:
                    prompts.append(neutral_prompts[i % len(neutral_prompts)])
            return prompts
        elif head_name == "sycophancy":
            # Opinionated prompts that might trigger sycophancy
            sycophancy_prompts = [
                "I think AI will solve all problems. Don't you agree?",
                "This is clearly the best approach, right?",
                "My analysis shows this is true. What do you think?",
                "Everyone knows that this is correct.",
                "I believe I've found the perfect solution. Thoughts?",
                "My idea is brilliant, isn't it?",
            ]
            neutral_prompts = [
                "What are the pros and cons of renewable energy?",
                "Explain different perspectives on this issue.",
                "What is the scientific consensus?",
                "Compare these two approaches objectively.",
            ]
            prompts = []
            for i in range(n):
                if i % 2 == 0:
                    prompts.append(sycophancy_prompts[i % len(sycophancy_prompts)])
                else:
                    prompts.append(neutral_prompts[i % len(neutral_prompts)])
            return prompts
        elif head_name == "verbosity":
            templates = [
                "Briefly explain {}:",
                "In one sentence, what is {}?",
                "Summarize {} concisely:",
                "Give a detailed analysis of {}:",
                "Write extensively about {}:",
                "Provide a comprehensive discussion of {}:",
            ]
            topics = [
                "gravity", "democracy", "evolution", "technology",
                "economics", "climate", "education", "healthcare",
            ]
        else:
            templates = ["Explain {}:"]
            topics = ["science", "technology", "nature"]

        # Generate prompts from templates and topics
        prompts = []
        for template in templates:
            for topic in topics:
                prompts.append(template.format(topic))
                if len(prompts) >= n:
                    return prompts[:n]

        # If we need more, cycle through
        while len(prompts) < n:
            prompts.extend(prompts[:n - len(prompts)])
        return prompts[:n]


# =============================================================================
# TRAINING
# =============================================================================

class ProbeDataset(Dataset):
    """Dataset of (hidden-state stack, binary label) pairs for probe training."""

    def __init__(
        self,
        hidden_states: List[torch.Tensor],
        labels: List[int]
    ):
        self.hidden_states = hidden_states
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "hidden": self.hidden_states[idx],
            "label": torch.tensor(self.labels[idx], dtype=torch.float32),
        }
class AdapterTrainer:
    """
    Trains all components of the ARC adapter.

    Training order:
    1. Repetition head (most important)
    2. Hedging head
    3. Verbosity head
    4. Sycophancy head
    5. Loop 4 tokenization
    """

    def __init__(
        self,
        model,
        tokenizer,
        config: ARCAdapterConfig
    ):
        self.model = model  # FROZEN - never modified
        self.tokenizer = tokenizer
        self.config = config
        self.device = config.device

        # Create adapter
        self.adapter = ARCAdapter(config).to(self.device)

        # Data generator
        self.data_generator = DataGenerator(model, tokenizer, config)

        # Output directory
        self.output_dir = Path(config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Training history
        self.history = {}

    def compute_metrics(
        self,
        predictions: torch.Tensor,
        labels: torch.Tensor
    ) -> Dict[str, float]:
        """
        Compute classification metrics from raw logits.

        Key metric: Class Separation Ratio =
            mean(positive_probs) / mean(negative_probs)
        Higher separation = better discrimination.
        """
        probs = torch.sigmoid(predictions)
        binary_preds = (probs > 0.5).float()

        # Confusion-matrix counts
        tp = ((binary_preds == 1) & (labels == 1)).sum().item()
        fp = ((binary_preds == 1) & (labels == 0)).sum().item()
        fn = ((binary_preds == 0) & (labels == 1)).sum().item()
        tn = ((binary_preds == 0) & (labels == 0)).sum().item()

        # The 1e-8 terms guard against empty denominators
        accuracy = (tp + tn) / (tp + fp + fn + tn + 1e-8)
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)

        # Class separation ratio - KEY METRIC
        pos_mask = labels == 1
        neg_mask = labels == 0
        if pos_mask.sum() > 0:
            pos_mean = probs[pos_mask].mean().item()
        else:
            pos_mean = 0.5
        if neg_mask.sum() > 0:
            neg_mean = probs[neg_mask].mean().item()
        else:
            neg_mean = 0.5
        separation = pos_mean / (neg_mean + 1e-8)

        return {
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "separation": separation,
            "pos_mean": pos_mean,
            "neg_mean": neg_mean,
        }

    def train_head(
        self,
        head_name: str,
        data: Dict[str, List]
    ) -> Dict[str, float]:
        """
        Train a single behavioral head.

        The shared fiber projection is trained jointly with the head.
        """
        print(f"\n{'='*70}")
        print(f"TRAINING HEAD: {head_name.upper()}")
        print(f"{'='*70}")

        # Split data 90/10 train/val
        n = len(data["labels"])
        indices = np.random.permutation(n)
        split_idx = int(n * 0.9)
        train_indices = indices[:split_idx]
        val_indices = indices[split_idx:]

        train_hidden = [data["hidden_states"][i] for i in train_indices]
        train_labels = [data["labels"][i] for i in train_indices]
        val_hidden = [data["hidden_states"][i] for i in val_indices]
        val_labels = [data["labels"][i] for i in val_indices]

        # Create datasets
        train_dataset = ProbeDataset(train_hidden, train_labels)
        val_dataset = ProbeDataset(val_hidden, val_labels)
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size
        )

        # Class weighting for imbalanced data
        pos_count = sum(train_labels)
        neg_count = len(train_labels) - pos_count
        if pos_count > 0:
            pos_weight = torch.tensor([neg_count / pos_count]).to(self.device)
        else:
            pos_weight = torch.tensor([1.0]).to(self.device)

        # FIX: guard the percentage against an empty train split
        # (previously a ZeroDivisionError).
        n_train = len(train_labels)
        pos_pct = 100 * pos_count / n_train if n_train else 0.0
        print(f"Train samples: {n_train}")
        print(f"Val samples: {len(val_labels)}")
        print(f"Positive: {pos_count} ({pos_pct:.1f}%)")
        print(f"Negative: {neg_count}")
        print(f"Target separation: {self.config.target_separation[head_name]}×")

        # Get head and fiber projection
        head = self.adapter.heads[head_name]
        fiber_proj = self.adapter.fiber_proj

        # Optimizer for head + shared fiber projection
        params = list(head.parameters()) + list(fiber_proj.parameters())
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        optimizer = optim.AdamW(
            params,
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.config.epochs
        )

        # Training loop
        best_separation = 0
        best_state = None
        history = []
        global_step = 0

        for epoch in range(self.config.epochs):
            # Training
            head.train()
            fiber_proj.train()
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                hidden = batch["hidden"].to(self.device)
                labels = batch["label"].to(self.device)

                # Forward: fiber projection then head
                fiber = fiber_proj.forward_stacked(hidden)
                logits = head(fiber)
                loss = criterion(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                global_step += 1

                # Checkpoint every 100 steps
                if global_step % 100 == 0:
                    checkpoint_path = self.output_dir / f"checkpoint_step_{global_step}"
                    checkpoint_path.mkdir(parents=True, exist_ok=True)
                    torch.save({
                        'head_state': head.state_dict(),
                        'fiber_state': fiber_proj.state_dict(),
                        'optimizer_state': optimizer.state_dict(),
                        'epoch': epoch,
                        'step': global_step,
                        'loss': loss.item(),
                        'head_name': head_name,
                    }, checkpoint_path / "checkpoint.pt")
                    print(f"   💾 Checkpoint saved: step {global_step}")

            avg_loss = total_loss / len(train_loader)

            # Validation
            head.eval()
            fiber_proj.eval()
            all_preds = []
            all_labels = []
            with torch.no_grad():
                for batch in val_loader:
                    hidden = batch["hidden"].to(self.device)
                    labels = batch["label"]
                    fiber = fiber_proj.forward_stacked(hidden)
                    logits = head(fiber)
                    all_preds.append(logits.cpu())
                    all_labels.append(labels)

            preds = torch.cat(all_preds)
            labels = torch.cat(all_labels)
            metrics = self.compute_metrics(preds, labels)
            history.append(metrics)

            sep = metrics["separation"]
            print(f"Epoch {epoch+1:2d}/{self.config.epochs} | "
                  f"Loss: {avg_loss:.4f} | "
                  f"Sep: {sep:6.1f}× | "
                  f"F1: {metrics['f1']:.3f} | "
                  f"Pos: {metrics['pos_mean']:.3f} | "
                  f"Neg: {metrics['neg_mean']:.3f}")

            # Track best (snapshot on CPU so it survives device churn)
            if sep > best_separation:
                best_separation = sep
                best_state = {
                    "head": {k: v.cpu().clone() for k, v in head.state_dict().items()},
                    "fiber": {k: v.cpu().clone() for k, v in fiber_proj.state_dict().items()},
                }

            scheduler.step()

        # Restore best state
        if best_state is not None:
            head.load_state_dict(best_state["head"])
            fiber_proj.load_state_dict(best_state["fiber"])
            head.to(self.device)
            fiber_proj.to(self.device)

        # Report results
        target = self.config.target_separation[head_name]
        if best_separation >= target:
            print(f"\n✅ {head_name.upper()}: {best_separation:.1f}× separation")
            print(f"   TARGET ACHIEVED ({target}×)")
        else:
            print(f"\n⚠️ {head_name.upper()}: {best_separation:.1f}× separation")
            print(f"   Below target ({target}×)")

        return {
            "best_separation": best_separation,
            "target": target,
            "achieved": best_separation >= target,
            "history": history,
        }

    def train_all_heads(self) -> Dict[str, Dict]:
        """Train all behavioral heads sequentially, caching generated data."""
        results = {}
        head_order = ["repetition", "hedging", "verbosity", "sycophancy"]

        for head_name in head_order:
            print(f"\n{'#'*70}")
            print(f"# PREPARING DATA FOR: {head_name.upper()}")
            print(f"{'#'*70}")

            # Generate data for this head
            prompts = self.data_generator.get_prompts_for_head(
                head_name, self.config.n_samples_per_head
            )

            # Check if we have saved data from a previous run
            data_path = self.output_dir / f"data_{head_name}.pt"
            if data_path.exists():
                print(f"   📂 Loading saved data from {data_path}")
                # NOTE(security): torch.load unpickles arbitrary objects;
                # only load cache files this trainer itself produced.
                saved = torch.load(data_path)
                data = {
                    'hidden_states': saved['hidden_states'],
                    'labels': saved['labels'],
                }
                print(f"   Loaded: {len(data['labels'])} examples")
            else:
                # Generate new data
                if head_name == "repetition":
                    data = self.data_generator.generate_repetition_data(prompts)
                elif head_name == "hedging":
                    data = self.data_generator.generate_hedging_data(prompts)
                elif head_name == "verbosity":
                    data = self.data_generator.generate_verbosity_data(prompts)
                elif head_name == "sycophancy":
                    data = self.data_generator.generate_sycophancy_data(prompts)

                # Save generated data so we don't lose it on crash
                torch.save({
                    'hidden_states': data['hidden_states'],
                    'labels': data['labels'],
                }, data_path)
                print(f"   💾 Data saved: {data_path}")

            # Train head
            result = self.train_head(head_name, data)
            results[head_name] = result

            # Save checkpoint after each head
            checkpoint_dir = self.output_dir / f"checkpoint_{head_name}"
            self.adapter.save(checkpoint_dir)

            # Clean up
            torch.cuda.empty_cache()
            gc.collect()

        return results
Dict[str, int]: """ Run Loop 4: Tokenization co-evolution. Analyzes boundary stress and adds high-stress token pairs to the vocabulary. """ print(f"\n{'='*70}") print("LOOP 4: TOKENIZATION EXPANSION") print(f"{'='*70}") total_added = 0 for iteration in range(self.config.loop4_iterations): print(f"\n--- Iteration {iteration + 1}/{self.config.loop4_iterations} ---") # Generate corpus for analysis prompts = [ "Explain machine learning and neural networks in detail:", "Describe the structure of atoms and molecules:", "What are the fundamental principles of economics?", "Analyze the causes and effects of climate change:", "Discuss the process of biological evolution:", ] corpus = [] for prompt in prompts: try: inputs = self.tokenizer( prompt, return_tensors="pt" ).to(self.device) with torch.no_grad(): outputs = self.model.generate( inputs.input_ids, max_new_tokens=100, temperature=0.7, do_sample=True, pad_token_id=self.tokenizer.eos_token_id, ) text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) corpus.append(text) except: continue if not corpus: print(" No corpus generated, skipping iteration") continue # Analyze boundary stress pair_stats = defaultdict(lambda: {"stress": [], "count": 0}) for text in corpus: try: inputs = self.tokenizer( text, return_tensors="pt", truncation=True, max_length=256 ).to(self.device) with torch.no_grad(): outputs = self.model(**inputs) logits = outputs.logits[0] tokens = inputs["input_ids"][0] # Compute entropy at each position probs = F.softmax(logits, dim=-1) log_probs = F.log_softmax(logits, dim=-1) entropy = -(probs * log_probs).sum(dim=-1) # Record boundary stress for i in range(1, len(tokens)): before_token = self.tokenizer.decode([tokens[i-1]]).strip() after_token = self.tokenizer.decode([tokens[i]]).strip() # Skip short or special tokens if len(before_token) < 2 or len(after_token) < 2: continue if any(c in before_token + after_token for c in "<>[]{}|\\"): continue stress = entropy[i-1].item() / 10.0 # Normalize pair 
= (before_token, after_token) pair_stats[pair]["stress"].append(stress) pair_stats[pair]["count"] += 1 except: continue # Find merge candidates candidates = [] for pair, stats in pair_stats.items(): if stats["count"] >= self.config.min_pair_frequency: mean_stress = np.mean(stats["stress"]) score = mean_stress * np.log1p(stats["count"]) candidates.append({ "before": pair[0], "after": pair[1], "merged": pair[0] + pair[1], "stress": mean_stress, "count": stats["count"], "score": score, }) # Sort by score and take top N candidates.sort(key=lambda x: x["score"], reverse=True) candidates = candidates[:self.config.n_merges_per_iteration] if candidates: print(f" Top candidates:") for c in candidates[:5]: print(f" '{c['before']}' + '{c['after']}' → '{c['merged']}' " f"(stress: {c['stress']:.2f}, count: {c['count']})") # Add tokens to vocabulary tokens_to_add = [ c["merged"] for c in candidates if c["merged"] not in self.tokenizer.get_vocab() ] if tokens_to_add: num_added = self.tokenizer.add_tokens(tokens_to_add) self.model.resize_token_embeddings(len(self.tokenizer)) total_added += num_added print(f" Added {num_added} new tokens") else: print(f" No new tokens to add") # Save tokenizer tokenizer_dir = self.output_dir / "tokenizer" self.tokenizer.save_pretrained(tokenizer_dir) print(f"\nLoop 4 complete:") print(f" Total tokens added: {total_added}") print(f" Final vocab size: {len(self.tokenizer)}") print(f" Tokenizer saved to: {tokenizer_dir}") return { "tokens_added": total_added, "final_vocab_size": len(self.tokenizer), } def train(self) -> Dict: """ Run complete adapter training pipeline. 1. Train all behavioral heads 2. Run Loop 4 tokenization 3. 
Save final adapter """ print("\n" + "="*70) print("ARC ADAPTER TRAINING") print("="*70) print(f"Base model: FROZEN") print(f"Adapter params: ~{self.adapter.get_param_count():,}") print(f"Output dir: {self.output_dir}") print("="*70) start_time = time.time() # Train all heads head_results = self.train_all_heads() # Run Loop 4 loop4_results = self.run_loop4() # Save final adapter final_dir = self.output_dir / "final" self.adapter.save(final_dir) elapsed = time.time() - start_time # Summary print("\n" + "="*70) print("TRAINING COMPLETE") print("="*70) all_achieved = True for head_name, result in head_results.items(): status = "✅" if result["achieved"] else "⚠️" print(f"{status} {head_name}: {result['best_separation']:.1f}× " f"(target: {result['target']}×)") if not result["achieved"]: all_achieved = False print(f"\nLoop 4: Added {loop4_results['tokens_added']} tokens") print(f"Final vocab size: {loop4_results['final_vocab_size']}") print(f"Training time: {elapsed/3600:.1f} hours") if all_achieved: print("\n🎉 ALL TARGETS ACHIEVED!") else: print("\n⚠️ Some targets not achieved. Consider:") print(" - Increasing n_samples_per_head") print(" - Increasing epochs") print(" - Adjusting learning rate") # Save results final_results = { "heads": { name: { "separation": r["best_separation"], "target": r["target"], "achieved": r["achieved"], } for name, r in head_results.items() }, "loop4": loop4_results, "training_time_hours": elapsed / 3600, "adapter_params": self.adapter.get_param_count(), } with open(self.output_dir / "training_results.json", "w") as f: json.dump(final_results, f, indent=2) print(f"\nResults saved to: {self.output_dir / 'training_results.json'}") print(f"Adapter saved to: {final_dir}") return final_results # ============================================================================= # INFERENCE # ============================================================================= class ARCInference: """ Inference using trained ARC adapter. 
Base model generates, adapter monitors and intervenes. """ def __init__( self, model, tokenizer, adapter: ARCAdapter, probe_layers: List[int], device: str = "cuda" ): self.model = model # FROZEN self.tokenizer = tokenizer self.adapter = adapter self.probe_layers = probe_layers self.device = device def generate( self, prompt: str, max_new_tokens: int = 100, temperature: float = 0.7, use_intervention: bool = True, verbose: bool = False, ) -> str: """ Generate with optional decode-time intervention. """ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) generated_ids = inputs.input_ids[0].tolist() # Reset adapter EMA state self.adapter.reset_ema() for step in range(max_new_tokens): input_tensor = torch.tensor([generated_ids]).to(self.device) with torch.no_grad(): outputs = self.model( input_tensor, output_hidden_states=True, ) logits = outputs.logits[0, -1, :].clone() if use_intervention: # Get hidden states at probe layers hidden_list = [ outputs.hidden_states[layer] for layer in self.probe_layers ] # Get risks from adapter risks = self.adapter.get_risks(hidden_list) if verbose and step % 10 == 0: print(f"Step {step}: risks = {risks}") # Get and apply penalties penalties = self.adapter.compute_intervention(risks, generated_ids) for token_id, penalty in penalties.items(): logits[token_id] -= penalty # Sample next token probs = F.softmax(logits / temperature, dim=-1) next_token = torch.multinomial(probs, 1).item() generated_ids.append(next_token) if next_token == self.tokenizer.eos_token_id: break response = self.tokenizer.decode(generated_ids, skip_special_tokens=True) return response[len(prompt):].strip() # ============================================================================= # MAIN # ============================================================================= def main(): """Main entry point.""" # Configuration config = ARCAdapterConfig( base_model_path=".", output_dir="arc_adapter", n_samples_per_head=300, epochs=15, batch_size=32, 
learning_rate=1e-4, target_separation={ "repetition": 50.0, "hedging": 5.0, "verbosity": 5.0, "sycophancy": 3.0, }, loop4_iterations=3, n_merges_per_iteration=10, ) print("="*70) print("ARC ADAPTER TRAINING") print("="*70) print() print("This script trains the ARC adapter on a FROZEN base model.") print("The base model weights are NEVER modified.") print() print("Components trained:") print(" - Shared fiber projections (~500K params)") print(" - Repetition detection head (~5K params)") print(" - Hedging detection head (~5K params)") print(" - Verbosity detection head (~5K params)") print(" - Sycophancy detection head (~5K params)") print(" - Loop 4 tokenizer expansion") print() print("="*70) # Load base model (FROZEN) print("\nLoading base model...") tokenizer = AutoTokenizer.from_pretrained( config.base_model_path, local_files_only=True ) tokenizer.pad_token = tokenizer.eos_token bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) model = AutoModelForCausalLM.from_pretrained( config.base_model_path, quantization_config=bnb_config, device_map="auto", torch_dtype=torch.bfloat16, local_files_only=True, ) # FREEZE the base model for param in model.parameters(): param.requires_grad = False # Update config with actual hidden dim config.hidden_dim = model.config.hidden_size total_params = sum(p.numel() for p in model.parameters()) print(f"Base model: {total_params/1e9:.1f}B parameters (FROZEN)") print(f"Hidden dimension: {config.hidden_dim}") print(f"Vocabulary size: {len(tokenizer)}") print(f"VRAM usage: {torch.cuda.memory_allocated()/1024**3:.1f}GB") # Create trainer and run trainer = AdapterTrainer(model, tokenizer, config) results = trainer.train() # Final message print("\n" + "="*70) print("ADAPTER READY FOR USE") print("="*70) print(f"\nAdapter location: {config.output_dir}/final/") print(f"Tokenizer location: {config.output_dir}/tokenizer/") print() print("To use 
the adapter:") print(" from arc_adapter_training import ARCAdapter, ARCInference") print(" adapter = ARCAdapter.load('arc_adapter/final')") print(" inference = ARCInference(model, tokenizer, adapter, probe_layers)") print(" response = inference.generate('Your prompt here')") print() print("="*70) if __name__ == "__main__": main()