| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
class CaptionBertConfig(PretrainedConfig):
    """Configuration for :class:`CaptionBertModel`.

    Bundles the transformer-encoder hyperparameters (vocab size, depth,
    width, attention heads) together with the optional alignment-bank
    settings (the ``bank_*`` arguments).
    """

    model_type = "caption_bert"

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=8192,
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=6,
        intermediate_size=1536,
        output_dim=768,
        hidden_dropout_prob=0.0,
        pad_token_id=0,
        # Alignment-bank options (see AlignmentBank).
        bank_enabled=True,
        bank_n_experts=5,
        bank_n_anchors=512,
        bank_dim=128,
        bank_cv_target=0.082,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        # Record every hyperparameter verbatim so HF config
        # (de)serialization round-trips cleanly.
        settings = (
            ("vocab_size", vocab_size),
            ("max_position_embeddings", max_position_embeddings),
            ("hidden_size", hidden_size),
            ("num_attention_heads", num_attention_heads),
            ("num_hidden_layers", num_hidden_layers),
            ("intermediate_size", intermediate_size),
            ("output_dim", output_dim),
            ("hidden_dropout_prob", hidden_dropout_prob),
            ("bank_enabled", bank_enabled),
            ("bank_n_experts", bank_n_experts),
            ("bank_n_anchors", bank_n_anchors),
            ("bank_dim", bank_dim),
            ("bank_cv_target", bank_cv_target),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)
|
|
|
|
class AlignmentBank(nn.Module):
    """
    Geometric interface layer preserving 5-expert differentiation structure.

    Trained post-hoc on frozen encoder via GPA + whitened Procrustes.
    Stores per-expert rotation matrices, whiteners, and means that encode
    how each expert's geometric perspective differs from the consensus center.

    Provides geometric context annotations (``d_bank``-dim) alongside the
    core ``d_embed``-dim consensus embedding for downstream heads.
    """

    def __init__(self, d_embed=768, n_experts=5, n_anchors=512, d_bank=128):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank

        # Per-expert alignment parameters, identity-initialized: every
        # expert starts out agreeing exactly with the consensus center.
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
        self.expert_whiteners = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)])

        # Learned anchor directions, unit-normalized at init; similarities
        # against them form a fixed reference frame for each embedding.
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(n_anchors, d_embed), dim=-1))

        # Geometric feature layout fed to geo_proj (must match forward()):
        #   per-expert cosine (n) + per-expert recon MSE (n)
        #   + cross-expert cosine (n choose 2) + disagreement ratio (1)
        #   + per-expert norm ratio (n) + anchor similarities (n_anchors)
        n_cross = n_experts * (n_experts - 1) // 2
        geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 2), nn.GELU(), nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank), nn.LayerNorm(d_bank))

        # Calibration targets measured on the training distribution;
        # presumably filled in by the post-hoc training script — these are
        # the placeholder defaults until then.
        self.register_buffer("target_cv", torch.tensor(0.082))
        self.register_buffer("target_mean_cos", torch.tensor(0.0))
        self.register_buffer("target_spectral", torch.zeros(50))
        self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
        self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
        self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))

    def forward(self, embedding):
        """Annotate ``embedding`` with geometric context.

        Args:
            embedding: (B, d_embed) consensus embeddings. Assumed to be
                L2-normalized by the encoder; the anchor similarities are
                true cosines only under that assumption.

        Returns:
            enriched: (B, d_embed + d_bank) embedding concatenated with the
                geometric context (cast to the embedding's dtype).
            geo_context: (B, d_bank) projected geometric features (float32).
            diagnostics: dict of scalar summary statistics.
        """
        # Work in float32 internally for numerical stability.
        emb = embedding.float()

        # Per-expert round-trip consistency: whiten/center into the expert's
        # frame, rotate in and back out, and measure how much survives.
        # The whitened norms are collected here too (they were previously
        # recomputed in a second, redundant whitening pass).
        expert_consistency = []
        expert_recon = []
        expert_projected = []
        expert_norms = []
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]
            whitened = (emb - mu) @ W
            expert_norms.append(whitened.norm(dim=-1))
            whitened_n = F.normalize(whitened, dim=-1)
            in_expert = whitened_n @ R.T   # into the expert's frame
            back = in_expert @ R           # round-trip back to consensus
            expert_consistency.append(
                F.cosine_similarity(whitened_n, back, dim=-1))
            expert_recon.append((whitened_n - back).pow(2).mean(dim=-1))
            expert_projected.append(in_expert)

        expert_cos = torch.stack(expert_consistency, dim=-1)   # (B, n)
        expert_mse = torch.stack(expert_recon, dim=-1)         # (B, n)

        # Pairwise agreement between expert frames.
        cross_cos = []
        for i in range(self.n_experts):
            for j in range(i + 1, self.n_experts):
                cross_cos.append(F.cosine_similarity(
                    expert_projected[i], expert_projected[j], dim=-1))
        cross_features = torch.stack(cross_cos, dim=-1)        # (B, nC2)

        # Spread of per-expert consistency relative to its mean.
        per_sample_agreement = expert_cos.mean(dim=-1)
        per_sample_disagreement = expert_cos.std(dim=-1)
        disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)

        # Relative whitened-norm per expert, mean-normalized across experts.
        norm_ratio = torch.stack(expert_norms, dim=-1)
        norm_ratio = norm_ratio / (norm_ratio.mean(dim=-1, keepdim=True) + 1e-8)

        # Similarity against the (re-normalized) anchor set.
        anchors_n = F.normalize(self.anchors, dim=-1)
        anchor_cos = emb @ anchors_n.T                         # (B, n_anchors)

        # Project the concatenated geometric features to the bank dimension.
        geo_input = torch.cat([
            expert_cos, expert_mse, cross_features,
            disagreement_ratio.unsqueeze(-1), norm_ratio, anchor_cos
        ], dim=-1)
        geo_context = self.geo_proj(geo_input)
        # Cast so the concat cannot fail when the encoder runs in a
        # non-fp32 dtype (e.g. fp16/bf16 inference).
        enriched = torch.cat(
            [embedding, geo_context.to(embedding.dtype)], dim=-1)

        # Scalar summaries for logging / drift monitoring.
        diagnostics = {
            "expert_cos_mean": expert_cos.mean().item(),
            "expert_cos_std": expert_cos.std().item(),
            "cross_expert_cos": cross_features.mean().item(),
            "cross_expert_cos_std": cross_features.std().item(),
            "anchor_max_cos": anchor_cos.max(dim=-1).values.mean().item(),
            "anchor_mean_cos": anchor_cos.mean().item(),
            "disagreement_ratio": disagreement_ratio.mean().item(),
            "norm_ratio_spread": norm_ratio.std(dim=-1).mean().item(),
        }

        return enriched, geo_context, diagnostics
|
|
|
|
class CaptionBertModel(PreTrainedModel):
    """
    Consensus-distilled caption encoder with geometric alignment bank.

    The encoder produces L2-normalized 768-dim embeddings in the geometric
    consensus space of 5 BERT-family models (BERT, ModernBERT, RoBERTa,
    ALBERT, DistilBERT), aligned via Generalized Procrustes Analysis.

    The alignment bank annotates each embedding with 128-dim geometric
    context from the 5-expert differentiation structure — per-expert
    consistency, cross-expert disagreement, and anchor distances.

    Output fields:
        last_hidden_state: (B, 768)     L2-normalized consensus embedding
        pooler_output:     (B, 768)     same (HF compatibility)
        token_embeddings:  (B, L, 384)  pre-pooling token representations
        enriched:          (B, 896)     embedding + bank geometric context
        geometric_context: dict         expert cos, cross-expert, anchors, etc.
        hidden_states:     tuple        per-layer outputs (if requested)
    """

    config_class = CaptionBertConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token + learned absolute position embeddings.
        self.token_emb = nn.Embedding(
            config.vocab_size, config.hidden_size,
            padding_idx=config.pad_token_id)
        self.pos_emb = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.emb_norm = nn.LayerNorm(config.hidden_size)
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)

        # Pre-norm transformer encoder stack.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.num_hidden_layers,
            enable_nested_tensor=False)

        # Projects the pooled hidden state into the output consensus space.
        self.output_proj = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.output_dim),
        )

        # Optional geometric alignment bank (trained post-hoc).
        if getattr(config, 'bank_enabled', False):
            self.bank = AlignmentBank(
                d_embed=config.output_dim,
                n_experts=config.bank_n_experts,
                n_anchors=config.bank_n_anchors,
                d_bank=config.bank_dim,
            )
        else:
            self.bank = None

        self.post_init()

    def forward(self, input_ids=None, attention_mask=None,
                output_hidden_states=False, **kwargs):
        """Encode token ids into consensus embeddings.

        Args:
            input_ids: (B, L) token ids. Required despite the HF-style
                ``None`` default in the signature.
            attention_mask: (B, L) with 1 for real tokens, 0 for padding;
                when absent, padding is inferred from ``pad_token_id``.
            output_hidden_states: also return every layer's output.
            **kwargs: ignored (accepted for HF-style call compatibility).

        Returns:
            An attribute-access container with the fields documented on the
            class docstring.

        Raises:
            ValueError: if ``input_ids`` is not provided.
        """
        from types import SimpleNamespace

        if input_ids is None:
            raise ValueError("CaptionBertModel.forward requires input_ids")
        _, L = input_ids.shape
        device = input_ids.device

        # Embed tokens + positions, then normalize and (optionally) drop out.
        positions = torch.arange(L, device=device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.emb_drop(self.emb_norm(x))

        # nn.Transformer convention: True entries are IGNORED by attention.
        if attention_mask is not None:
            key_padding_mask = ~attention_mask.bool()
        else:
            key_padding_mask = (input_ids == self.config.pad_token_id)

        # Run the layers manually (instead of self.encoder(x)) so that
        # per-layer states can be collected when requested.
        hidden_states = [x] if output_hidden_states else None
        for layer in self.encoder.layers:
            x = layer(x, src_key_padding_mask=key_padding_mask)
            if output_hidden_states:
                hidden_states.append(x)

        # Masked mean pooling over non-padding tokens; clamp avoids a
        # division by zero for an all-padding row.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
        else:
            mask = (~key_padding_mask).unsqueeze(-1).float()
        pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        embedding = F.normalize(self.output_proj(pooled), dim=-1)

        # Optional geometric annotation from the alignment bank.
        enriched = None
        geo_diagnostics = None
        if self.bank is not None:
            enriched, _, geo_diagnostics = self.bank(embedding)

        result = {
            'last_hidden_state': embedding,
            'pooler_output': embedding,
            'token_embeddings': x,
            'enriched': enriched,
            'geometric_context': geo_diagnostics,
        }
        if output_hidden_states:
            result['hidden_states'] = tuple(hidden_states)

        # Lightweight attribute-style container (same access pattern as the
        # previous ad-hoc type() construction, without a class per call).
        return SimpleNamespace(**result)

    def encode(self, texts, tokenizer=None, max_length=512, batch_size=128,
               device=None):
        """Convenience: raw text -> L2-normalized (N, 768) embeddings.

        Args:
            texts: a string or list of strings.
            tokenizer: HF tokenizer; defaults to bert-base-uncased's.
            max_length: truncation/padding length per batch.
            batch_size: number of texts encoded per forward pass.
            device: target device; defaults to the model's own device.

        Returns:
            (N, output_dim) float tensor on CPU (empty tensor for no input).
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            # torch.cat([]) raises, so short-circuit the empty case.
            return torch.empty(0, self.config.output_dim)
        if tokenizer is None:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        if device is None:
            device = next(self.parameters()).device
        was_training = self.training
        self.eval()
        all_emb = []
        try:
            with torch.no_grad():
                for i in range(0, len(texts), batch_size):
                    batch = texts[i:i + batch_size]
                    inputs = tokenizer(
                        batch, max_length=max_length, padding="max_length",
                        truncation=True, return_tensors="pt"
                    ).to(device)
                    out = self(input_ids=inputs["input_ids"],
                               attention_mask=inputs["attention_mask"])
                    all_emb.append(out.last_hidden_state.cpu())
        finally:
            # Restore the caller's train/eval mode.
            if was_training:
                self.train()
        return torch.cat(all_emb)