AbstractPhil committed on
Commit
ebd9fd5
Β·
verified Β·
1 Parent(s): aa57a72

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +523 -0
model.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # DEEP BERT v3 — Teacher-Distilled Geometric Memory
3
+ #
4
+ # BERT-large (frozen, 512 ctx) student backbone +
5
+ # Geometric memory system (trainable, ~49M) +
6
+ # Projector heads (trainable) aligned to frozen long-context teachers.
7
+ #
8
+ # Teachers (frozen, run once per document):
9
+ # ModernBERT-large: 8192 ctx, 1024 hidden, 28 layers, RoPE + FlashAttn
10
+ # Longformer-large: 4096 ctx, 1024 hidden, 24 layers, sliding + global attn
11
+ #
12
+ # NO Procrustes at runtime. Projectors initialized from static pre-alignment.
13
+ # Bank uses direct cross-attention, no whitening, no alignment transforms.
14
+ # ============================================================================
15
+
16
+ import math
17
+ from dataclasses import dataclass
18
+ from typing import Any, Dict, List, Optional, Tuple
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from transformers import BertModel
24
+
25
+
26
+ # ══════════════════════════════════════════════════════════════════
27
+ # CONFIG
28
+ # ══════════════════════════════════════════════════════════════════
29
+
30
@dataclass
class DeepBertV3Config:
    """Hyper-parameters for the DeepBert v3 student model.

    Field order is part of the positional constructor and must not change.
    """

    # --- Student backbone -------------------------------------------------
    bert_model: str = "google-bert/bert-large-uncased"
    hidden_size: int = 1024
    freeze_bert: bool = True

    # --- Memory tokens prepended to every segment -------------------------
    n_memory_tokens: int = 16

    # --- Geometric bank ---------------------------------------------------
    bank_size: int = 128          # maximum number of anchors kept
    anchor_dim: int = 1024        # width of one stored anchor
    n_bank_heads: int = 8         # heads of the bank cross-attention
    bank_cross_layers: int = 2    # number of cross-attention layers

    # --- Gate: "gru" selects the GRU-style gate, anything else the simple one
    gate_type: str = "gru"

    # --- Multi-layer extraction: full depth profile -----------------------
    extract_layers: Tuple[int, ...] = (2, 5, 8, 11, 14, 17, 20, 23)
    layer_fusion: str = "learned"

    # --- Segment processing -----------------------------------------------
    max_content_tokens: int = 480
    segment_overlap: int = 64
    max_position: int = 512       # position ids are clamped below this value

    # --- Teacher specs (for projector sizing) -----------------------------
    n_teachers: int = 2
    teacher_hidden: int = 1024    # both ModernBERT-large and Longformer-large = 1024

    # --- Geometric --------------------------------------------------------
    cv_target: float = 0.2

    @property
    def n_extract_layers(self) -> int:
        """Number of backbone layers listed in ``extract_layers``."""
        return len(self.extract_layers)

    @property
    def depth_profile_dim(self) -> int:
        """Width of the concatenated depth profile (layers x hidden size)."""
        return len(self.extract_layers) * self.hidden_size
72
+
73
+
74
+ # ══════════════════════════════════════════════════════════════════
75
+ # GEOMETRIC UTILITIES
76
+ # ══════════════════════════════════════════════════════════════════
77
+
78
def cayley_menger_vol2(pts):
    """Squared simplex volume from the Cayley-Menger determinant.

    Args:
        pts: tensor of shape (batch, n_vertices, dim); each batch entry holds
            the vertices of one simplex.

    Returns:
        float32 tensor of shape (batch,) with the squared volume of each
        simplex (may be slightly negative through round-off).
    """
    # Determinants are computed in float32 with CUDA autocast switched off.
    with torch.amp.autocast("cuda", enabled=False):
        points = pts.float()
        delta = points.unsqueeze(-2) - points.unsqueeze(-3)
        sq_dist = (delta * delta).sum(-1)          # pairwise squared distances
        batch, n_vert, _ = sq_dist.shape

        # Bordered distance matrix: zero corner, ones along first row/column.
        size = n_vert + 1
        bordered = torch.zeros(batch, size, size,
                               device=sq_dist.device, dtype=torch.float32)
        bordered[:, 0, 1:] = 1
        bordered[:, 1:, 0] = 1
        bordered[:, 1:, 1:] = sq_dist

        sign = (-1.0) ** n_vert
        fact = math.factorial(n_vert - 1)
        coeff = sign / ((2.0 ** (n_vert - 1)) * fact * fact)
        return coeff * torch.linalg.det(bordered)
88
+
89
+
90
def pentachoron_cv(embeddings, n_samples=16):
    """Coefficient of variation (std / mean) of random pentachoron volumes.

    Args:
        embeddings: tensor whose first dimension indexes points; five points
            are drawn at random per sample.
        n_samples: number of random 5-point simplices to measure.

    Returns:
        Scalar tensor. Zero when fewer than five points are available.
    """
    n_points = embeddings.shape[0]
    if n_points < 5:
        return torch.tensor(0.0, device=embeddings.device)

    def _sample_volume():
        # One random 5-vertex simplex -> its (clamped, epsilon-padded) volume.
        pick = torch.randperm(n_points, device=embeddings.device)[:5]
        vol_sq = cayley_menger_vol2(embeddings[pick].unsqueeze(0))
        return torch.sqrt(F.relu(vol_sq[0]) + 1e-12)

    volumes = torch.stack([_sample_volume() for _ in range(n_samples)])
    return volumes.std() / (volumes.mean() + 1e-8)
102
+
103
+
104
+ # ══════════════════════════════════════════════════════════════════
105
+ # GEOMETRIC MEMORY BANK — clean, no Procrustes
106
+ # ══════════════════════════════════════════════════════════════════
107
+
108
class GeometricMemoryBank(nn.Module):
    """
    Bank stores compressed depth-profile anchors from each segment.
    Memory tokens query the bank via cross-attention.
    No alignment transform: both spaces are learned end-to-end.

    The bank itself is a plain dict (see ``init_bank``) that is passed in and
    returned by ``write``; this module only owns the learnable layers.
    """

    def __init__(self, config: DeepBertV3Config):
        super().__init__()
        self.config = config
        self.max_size = config.bank_size
        self.dim = config.anchor_dim

        hidden = config.hidden_size
        n_layers = config.bank_cross_layers

        # Depth-profile compressor:
        # (B, n_extract_layers * hidden_size) -> (B, anchor_dim)
        self.depth_compressor = nn.Sequential(
            nn.Linear(config.depth_profile_dim, hidden * 2),
            nn.GELU(),
            nn.LayerNorm(hidden * 2),
            nn.Linear(hidden * 2, config.anchor_dim),
        )

        # Temporal encoding of the (scaled) segment index
        self.temporal_proj = nn.Linear(1, config.anchor_dim, bias=False)

        # Cross-attention: memory tokens (Q, hidden_size wide) attend to bank
        # anchors (K, V, anchor_dim wide). kdim/vdim make the layer valid when
        # anchor_dim differs from hidden_size; when they are equal the module
        # keeps the same fused in-projection parameters as without them.
        self.cross_attn = nn.ModuleList([
            nn.MultiheadAttention(hidden, config.n_bank_heads,
                                  batch_first=True, dropout=0.1,
                                  kdim=config.anchor_dim,
                                  vdim=config.anchor_dim)
            for _ in range(n_layers)
        ])
        self.cross_norms = nn.ModuleList([
            nn.LayerNorm(hidden) for _ in range(n_layers)
        ])
        self.cross_ffns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden, hidden * 2),
                nn.GELU(),
                nn.Linear(hidden * 2, hidden),
            )
            for _ in range(n_layers)
        ])
        self.ffn_norms = nn.ModuleList([
            nn.LayerNorm(hidden) for _ in range(n_layers)
        ])

    def init_bank(self, batch_size: int, device: torch.device) -> Dict[str, Any]:
        """Return an empty bank: zero anchors of width ``anchor_dim``."""
        return {"anchors": torch.zeros(batch_size, 0, self.dim, device=device),
                "n_written": 0}

    def write(self, bank, content_hidden, attention_mask=None,
              segment_idx=0, depth_cls=None):
        """Append one anchor per batch row to the bank and return the new bank.

        Args:
            bank: dict with "anchors" (B, n, anchor_dim) and "n_written".
            content_hidden: (B, seq, hidden_size) token states; only used when
                ``depth_cls`` is None.
            attention_mask: optional (B, seq) mask for the pooled fallback.
            segment_idx: index of the current segment, used for the temporal
                offset.
            depth_cls: optional per-layer vectors, flattened per batch row to
                ``depth_profile_dim`` before compression.

        Returns:
            New dict with "anchors", "n_written" (incremented) and
            "live_anchor" (the anchor just written, shape (B, anchor_dim)).
            The input dict is not modified.
        """
        anchors = bank["anchors"]

        if depth_cls is not None:
            flat = depth_cls.reshape(depth_cls.shape[0], -1)
            anchor = self.depth_compressor(flat)
        else:
            # Fallback: (masked) mean-pool the content and tile it so that it
            # matches the compressor's expected input width.
            if attention_mask is not None:
                m = attention_mask.float().unsqueeze(-1)
                pooled = (content_hidden * m).sum(1) / m.sum(1).clamp(min=1)
            else:
                pooled = content_hidden.mean(dim=1)
            anchor = self.depth_compressor(
                pooled.repeat(1, self.config.n_extract_layers))

        anchor = F.normalize(anchor, dim=-1)

        # Temporal signal: small offset from the scaled segment index, then
        # re-normalise so anchors stay unit length.
        t = torch.tensor([[segment_idx]], dtype=anchor.dtype, device=anchor.device)
        anchor = anchor + 0.1 * self.temporal_proj(t / max(self.max_size, 1))
        anchor = F.normalize(anchor, dim=-1)

        # Append to bank and keep only the newest ``max_size`` anchors.
        anchors = torch.cat([anchors, anchor.unsqueeze(1)], dim=1)
        n_stored = anchors.shape[1]
        if n_stored > self.max_size:
            # Explicit start index: a ``-max_size:`` slice would keep
            # everything when max_size is 0.
            anchors = anchors[:, n_stored - self.max_size:]

        return {"anchors": anchors, "n_written": bank["n_written"] + 1,
                "live_anchor": anchor}

    def read(self, memory_tokens, bank):
        """Enrich memory tokens with bank content via pre-norm cross-attention.

        Returns ``memory_tokens`` unchanged (same object) when the bank holds
        no anchors.
        """
        anchors = bank["anchors"]
        if anchors.shape[1] == 0:
            return memory_tokens

        x = memory_tokens
        for attn, norm, ffn, ffn_norm in zip(
            self.cross_attn, self.cross_norms,
            self.cross_ffns, self.ffn_norms,
        ):
            attended, _ = attn(norm(x), anchors, anchors)
            x = x + attended
            x = x + ffn(ffn_norm(x))
        return x
206
+
207
+
208
+ # ══════════════════════════════════════════════════════════════════
209
+ # DELTA MEMORY GATE
210
+ # ══════════════════════════════════════════════════════════════════
211
+
212
class DeltaMemoryGate(nn.Module):
    """Blend the previous memory with a new memory proposal.

    ``gate_type == "gru"`` uses reset/update/candidate projections; any other
    value uses a single sigmoid gate that interpolates between old and new.
    The blended result is layer-normalised.
    """

    def __init__(self, config: DeepBertV3Config):
        super().__init__()
        width = config.hidden_size
        self.gate_type = config.gate_type
        if self.gate_type != "gru":
            self.gate_proj = nn.Linear(2 * width, width)
        else:
            self.reset_proj = nn.Linear(2 * width, width)
            self.update_proj = nn.Linear(2 * width, width)
            self.candidate_proj = nn.Linear(2 * width, width)
        self.norm = nn.LayerNorm(width)

    def forward(self, old, new):
        """Return the normalised gated mix of ``old`` and ``new``.

        Both inputs are concatenated on the last dimension, so they must
        share every dimension and be ``hidden_size`` wide.
        """
        joined = torch.cat([old, new], dim=-1)

        if self.gate_type != "gru":
            keep = torch.sigmoid(self.gate_proj(joined))
            return self.norm(keep * old + (1 - keep) * new)

        reset = torch.sigmoid(self.reset_proj(joined))
        update = torch.sigmoid(self.update_proj(joined))
        # The candidate sees the old state only through the reset gate.
        candidate = torch.tanh(
            self.candidate_proj(torch.cat([reset * old, new], dim=-1)))
        return self.norm(update * old + (1 - update) * candidate)
236
+
237
+
238
+ # ══════════════════════════════════════════════════════════════════
239
+ # MULTI-LAYER FUSION
240
+ # ══════════════════════════════════════════════════════════════════
241
+
242
class LayerFusion(nn.Module):
    """Softmax-weighted sum of several layer outputs, then Linear + LayerNorm.

    One learnable scalar per entry of ``config.extract_layers``; the scalars
    start equal, so the initial mix is a plain average.
    """

    def __init__(self, config: DeepBertV3Config):
        super().__init__()
        n_layers = len(config.extract_layers)
        width = config.hidden_size
        self.weights = nn.Parameter(torch.ones(n_layers) / n_layers)
        self.proj = nn.Linear(width, width)
        self.norm = nn.LayerNorm(width)

    def forward(self, layer_outputs):
        """Fuse a list of equally shaped 3-D tensors into one tensor."""
        mix = F.softmax(self.weights, dim=0)
        layers = torch.stack(layer_outputs)            # layer axis first
        blended = (layers * mix[:, None, None, None]).sum(0)
        return self.norm(self.proj(blended))
255
+
256
+
257
+ # ══════════════════════════════════════════════════════════════════
258
+ # TEACHER PROJECTOR — initialized from static Procrustes
259
+ # ══════════════════════════════════════════════════════════════════
260
+
261
class TeacherProjector(nn.Module):
    """
    Linear map from student space to teacher space.

    Starts as an identity-like map (identity weight, zero bias) and can be
    overwritten with a pre-computed Procrustes alignment through
    ``init_from_procrustes``; the layer stays trainable afterwards.
    """

    def __init__(self, student_dim: int, teacher_dim: int, name: str = ""):
        super().__init__()
        self.name = name
        self.proj = nn.Linear(student_dim, teacher_dim, bias=True)
        # Identity start; meant to be replaced by the Procrustes initialisation.
        nn.init.eye_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        """Apply the linear projection to ``x``."""
        return self.proj(x)

    def init_from_procrustes(self, rotation, student_mean, teacher_mean):
        """
        Load a pre-computed Procrustes alignment into the projector.

        rotation: (D, D) matrix applied as ``rotation @ x`` to a student vector
        student_mean, teacher_mean: (D,) centering vectors
        Afterwards forward(x) == rotation @ (x - student_mean) + teacher_mean,
        i.e. weight = rotation and bias = teacher_mean - rotation @ student_mean.
        """
        with torch.no_grad():
            offset = teacher_mean - rotation @ student_mean
            self.proj.weight.copy_(rotation)
            self.proj.bias.copy_(offset)
        print(f" [{self.name}] Procrustes init: |R|={rotation.norm():.3f}")
289
+
290
+
291
+ # ══════════════════════════════════════════════════════════════════
292
+ # DEEP BERT v3 MODEL
293
+ # ══════════════════════════════════════════════════════════════════
294
+
295
class DeepBertV3(nn.Module):
    """BERT backbone with prepended memory tokens, a geometric anchor bank,
    a memory gate and two teacher projector heads.

    The model is stateful across segments: ``forward`` takes a state dict
    (see ``init_state``) and returns the updated state alongside the outputs.
    """

    def __init__(self, config: DeepBertV3Config):
        super().__init__()
        self.config = config
        hidden = config.hidden_size

        # -- BERT backbone (optionally frozen) --
        self.bert = BertModel.from_pretrained(
            config.bert_model, add_pooling_layer=False,
            attn_implementation="eager")
        self.bert.config.output_hidden_states = True
        if config.freeze_bert:
            # Only gradients are disabled here; train/eval mode is untouched.
            for p in self.bert.parameters():
                p.requires_grad = False

        # -- Memory system --
        self.memory_embeddings = nn.Parameter(
            torch.randn(1, config.n_memory_tokens, hidden) * 0.02)
        self.layer_fusion = LayerFusion(config)
        self.bank = GeometricMemoryBank(config)
        self.gate = DeltaMemoryGate(config)

        # -- Output heads --
        self.output_proj = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(), nn.LayerNorm(hidden))
        self.memory_output_fusion = nn.Sequential(
            nn.Linear(hidden * 2, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden))

        # -- Teacher projectors (not called in forward) --
        self.proj_modern = TeacherProjector(
            hidden, config.teacher_hidden, "ModernBERT")
        self.proj_longformer = TeacherProjector(
            hidden, config.teacher_hidden, "Longformer")

    @classmethod
    def from_pretrained(cls, config=None, **kwargs):
        """Build the model and print a parameter summary.

        When ``config`` is None, a ``DeepBertV3Config`` is built from
        ``kwargs``. Only the BERT backbone loads pretrained weights (inside
        ``__init__``).
        """
        if config is None:
            config = DeepBertV3Config(**kwargs)
        model = cls(config)
        n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
        n_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
        print("DeepBert v3 initialized:")
        print(f" BERT: {n_frozen:,} frozen")
        print(f" Memory + projectors: {n_train:,} trainable")
        print(f" Extract: {config.extract_layers} → {config.depth_profile_dim}-dim anchor")
        print(f" Bank: {config.bank_size} anchors, {config.bank_cross_layers} cross-attn")
        print(f" Memory: {config.n_memory_tokens} tokens, {config.gate_type} gate")
        return model

    def init_state(self, batch_size, device=None):
        """Create the initial recurrent state for a batch of documents.

        Args:
            batch_size: number of rows in the batch.
            device: target device; defaults to the device of the model's
                first parameter.

        Returns:
            dict with "memory" (B, n_memory_tokens, hidden_size), an empty
            "bank" and "segment_idx" set to 0.
        """
        if device is None:
            device = next(self.parameters()).device
        # Memory is placed on the requested device as well, so it always
        # matches the bank created below.
        memory = self.memory_embeddings.expand(batch_size, -1, -1).clone().to(device)
        return {
            "memory": memory,
            "bank": self.bank.init_bank(batch_size, device),
            "segment_idx": 0,
        }

    def forward(self, input_ids, attention_mask, state):
        """Process one segment and advance the recurrent state.

        Args:
            input_ids: (B, seq_len) token ids of the segment.
            attention_mask: (B, seq_len) mask for ``input_ids``.
            state: dict with "memory", "bank" and "segment_idx", as produced
                by ``init_state`` or a previous call.

        Returns:
            (outputs, new_state). ``outputs`` holds "memory_output",
            "cls_output", "live_anchor", "depth_cls", "content_output" and
            "memory_tokens". ``new_state`` is not detached; use
            ``detach_state`` to cut the graph between segments.
        """
        cfg = self.config
        B, seq_len = input_ids.shape[0], input_ids.shape[1]
        device = input_ids.device
        n_mem = cfg.n_memory_tokens

        memory_state = state["memory"]
        bank = state["bank"]
        seg_idx = state["segment_idx"]

        # -- Bank read: enrich memory tokens --
        memory_tokens = self.bank.read(memory_state, bank)

        # -- Build BERT input with memory tokens prepended --
        content_embeds = self.bert.embeddings.word_embeddings(input_ids)
        inputs_embeds = torch.cat([memory_tokens, content_embeds], dim=1)

        # Positions run 0 .. n_mem + seq_len - 1 over memory then content and
        # are clamped to the last valid position id.
        position_ids = (
            torch.arange(n_mem + seq_len, device=device)
            .unsqueeze(0)
            .expand(B, -1)
            .clamp(max=cfg.max_position - 1)
        )

        # Memory tokens use token type 1, content tokens type 0.
        token_type_ids = torch.cat([
            torch.ones(B, n_mem, dtype=torch.long, device=device),
            torch.zeros(B, seq_len, dtype=torch.long, device=device),
        ], dim=1)

        # Memory tokens are always attended to.
        full_mask = torch.cat([
            torch.ones(B, n_mem, device=device, dtype=attention_mask.dtype),
            attention_mask,
        ], dim=1)

        # -- BERT forward --
        bert_out = self.bert(
            inputs_embeds=inputs_embeds, attention_mask=full_mask,
            position_ids=position_ids, token_type_ids=token_type_ids,
            output_hidden_states=True, return_dict=True)

        # -- Multi-layer extraction --
        # Index i + 1 is used for extract layer i (hidden_states[0] is skipped).
        selected = [bert_out.hidden_states[i + 1] for i in cfg.extract_layers]
        hidden = self.layer_fusion(selected)
        memory_output = hidden[:, :n_mem]
        content_output = hidden[:, n_mem:]

        # Depth profile: first content position (index n_mem) of each layer.
        depth_cls = torch.stack([h[:, n_mem, :] for h in selected], dim=1)

        # -- Gate --
        new_memory = self.gate(memory_state, memory_output)

        # -- Bank write --
        new_bank = self.bank.write(bank, content_output, attention_mask,
                                   seg_idx, depth_cls=depth_cls)

        # -- Output: CLS residual --
        cls_output = self.output_proj(content_output[:, 0])
        memory_delta = self.memory_output_fusion(
            torch.cat([cls_output, new_memory.mean(dim=1)], dim=-1))
        fused = cls_output + memory_delta

        outputs = {
            "memory_output": fused,
            "cls_output": cls_output,
            "live_anchor": new_bank["live_anchor"],
            "depth_cls": depth_cls,
            "content_output": content_output,
            "memory_tokens": new_memory,
        }

        # LIVE state: the trainer controls TBPTT
        new_state = {
            "memory": new_memory,
            "bank": {"anchors": new_bank["anchors"],
                     "n_written": new_bank["n_written"],
                     "live_anchor": new_bank["live_anchor"]},
            "segment_idx": seg_idx + 1,
        }
        return outputs, new_state

    @staticmethod
    def detach_state(state):
        """Return a copy of ``state`` whose tensors are detached.

        The bank's "live_anchor" entry, if present, is not carried over.
        """
        bank = state["bank"]
        return {
            "memory": state["memory"].detach(),
            "bank": {"anchors": bank["anchors"].detach(),
                     "n_written": bank["n_written"]},
            "segment_idx": state["segment_idx"],
        }

    def get_trainable_params(self):
        """List every parameter with ``requires_grad`` set."""
        return [p for p in self.parameters() if p.requires_grad]

    def num_trainable_params(self):
        """Total element count of the trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
448
+
449
+
450
+ # ══════════════════════════════════════════════════════════════════
451
+ # STATIC PROCRUSTES — computed once, used to init projectors
452
+ # ══════════════════════════════════════════════════════════════════
453
+
454
@torch.no_grad()
def compute_static_procrustes(student_embs, teacher_embs):
    """
    Orthogonal Procrustes alignment of centred student to teacher embeddings.

    Solves min_Q ||Xc @ Q - Yc||_F over orthogonal Q and returns R = Q.T, so
    that the aligned student is ``Xc @ R.T`` (equivalently ``R @ x`` for a
    single centred vector x). This is the convention expected by an
    ``nn.Linear`` weight.

    Args:
        student_embs: (N, D) student embeddings.
        teacher_embs: (N, D) teacher embeddings, row-aligned with the student.

    Returns:
        (R, student_mean, teacher_mean): R is (D, D); the means are (D,).
    """
    X = student_embs.float()
    Y = teacher_embs.float()
    mu_x, mu_y = X.mean(0), Y.mean(0)
    Xc, Yc = X - mu_x, Y - mu_y

    # SVD of the cross-covariance gives the optimal orthogonal map U @ Vt.
    U, _, Vt = torch.linalg.svd(Xc.T @ Yc)
    R = (U @ Vt).T  # (D, D): maps student -> teacher when applied as R @ x

    # Report mean cosine similarity before and after alignment.
    cos_before = F.cosine_similarity(Xc, Yc, dim=-1).mean()
    cos_after = F.cosine_similarity(Xc @ R.T, Yc, dim=-1).mean()
    print(f" Procrustes: cos {cos_before:.4f} → {cos_after:.4f}")
    return R, mu_x, mu_y
470
+
471
+
472
+ # ═════════════════════════��════════════════════════════════════════
473
+ # SANITY CHECK
474
+ # ══════════════════════════════════════════════════════════════════
475
+
476
if __name__ == "__main__":
    # Sanity check: build the model, print a per-component parameter count and
    # push three short segments through it while carrying the state forward.
    print("=" * 70)
    print("DEEP BERT v3 — Teacher-Distilled Geometric Memory")
    print("=" * 70)

    config = DeepBertV3Config()
    model = DeepBertV3.from_pretrained(config)

    def _count(module):
        """Total number of parameter elements in ``module``."""
        return sum(p.numel() for p in module.parameters())

    comps = {
        "memory_embeddings": model.memory_embeddings.numel(),
        "layer_fusion": _count(model.layer_fusion),
        "bank.depth_compressor": _count(model.bank.depth_compressor),
        "bank.temporal_proj": _count(model.bank.temporal_proj),
        "bank.cross_attn": _count(model.bank.cross_attn),
        "bank.cross_ffns": _count(model.bank.cross_ffns),
        "gate": _count(model.gate),
        "output_proj": _count(model.output_proj),
        "memory_output_fusion": _count(model.memory_output_fusion),
        "proj_modern": _count(model.proj_modern),
        "proj_longformer": _count(model.proj_longformer),
    }
    print("\n Component breakdown:")
    for k, v in comps.items():
        print(f" {k:30s}: {v:,}")
    print(f" {'TOTAL':30s}: {sum(comps.values()):,}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # eval() turns dropout off so the inference pass below is deterministic.
    model = model.to(device).eval()

    from transformers import BertTokenizer
    tok = BertTokenizer.from_pretrained(config.bert_model)
    state = model.init_state(1, device)
    texts = [
        "The quick brown fox jumps over the lazy dog near the riverbank.",
        "Meanwhile the cat sat on the mat observing everything carefully.",
        "Both animals eventually fell asleep under the warm afternoon sun.",
    ]
    for i, text in enumerate(texts):
        tokens = tok(text, return_tensors="pt", padding="max_length",
                     truncation=True, max_length=config.max_content_tokens)
        with torch.no_grad():
            out, state = model(tokens["input_ids"].to(device),
                               tokens["attention_mask"].to(device), state)
        print(f"\n Seg {i+1}: anchor={out['live_anchor'].shape}, "
              f"fused={out['memory_output'].shape}, "
              f"bank={state['bank']['anchors'].shape[1]}")

    print("\nDone.")