| """ | |
| AETHER-Net: Main Model | |
| Adaptive Elemental Transformer-Hybrid Efficient Recurrent Network | |
| 25-layer hybrid LLM with 5Γ5 Latin orthogonal magic square layout | |
| and Oheng (δΊθ‘) MoE routing. | |
| """ | |
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple

from config import AetherNetConfig, ELEMENTS, LAYER_TO_ELEMENT, ELEMENT_LAYERS
from layers import RMSNorm, build_attention
from oheng_moe import OhengMoE


class AetherNetBlock(nn.Module):
    """Single AETHER-Net transformer block.

    Structure:
        x → RMSNorm → Attention → residual → RMSNorm → OhengMoE → residual → out
    """

    def __init__(self, config: AetherNetConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_type = config.get_layer_type(layer_idx)
        self.element = config.get_layer_element(layer_idx)

        # Pre-norm
        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)

        # Attention (type determined by magic square)
        self.attention = build_attention(self.layer_type, config)

        # MoE FFN with Oheng routing
        self.moe = OhengMoE(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        element_states: Optional[Dict[str, torch.Tensor]] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention block with residual
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        hidden_states = residual + hidden_states

        # MoE FFN block with residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.moe(hidden_states, element_states=element_states)
        hidden_states = residual + hidden_states

        return hidden_states
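
# Usage sketch for a single block (illustrative only; the shapes are an assumption
# based on the usual [batch, seq_len, hidden_size] convention used in this file):
#   block = AetherNetBlock(config, layer_idx=0)
#   y = block(x)  # x: [B, L, hidden_size] -> y: [B, L, hidden_size]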


class AetherNetModel(nn.Module):
    """AETHER-Net Language Model.

    Architecture:
        - Embedding → 25 × AetherNetBlock → RMSNorm → LM Head
        - Blocks arranged in a 5×5 Latin orthogonal magic square
        - Oheng MoE with sangsaeng (상생, generating) and sanggeuk (상극, overcoming) connections
        - Element states flow between element groups for structural self-verification
    """

    def __init__(self, config: AetherNetConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # 25 transformer blocks
        self.layers = nn.ModuleList([
            AetherNetBlock(config, layer_idx=i)
            for i in range(config.num_layers)
        ])

        # Final norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)

        # LM Head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

        # Initialize
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        B, L = input_ids.shape

        # Position IDs
        if position_ids is None:
            position_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)

        # Embed
        hidden_states = self.embed_tokens(input_ids)

        # ── Element state tracking for Oheng connections ──
        # Each element group accumulates its output for generating/overcoming
        # (sangsaeng/sanggeuk) routing
        element_states: Dict[str, torch.Tensor] = {}
        element_layer_counts: Dict[str, int] = {e: 0 for e in ELEMENTS}

        # ── Forward through 25 layers ──
        for i, layer in enumerate(self.layers):
            element = LAYER_TO_ELEMENT[i]
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                element_states=element_states,
                encoder_hidden_states=encoder_hidden_states,
            )

            # Update element state (running average of this element's layer outputs)
            element_layer_counts[element] += 1
            count = element_layer_counts[element]
            if element in element_states:
                # Running (cumulative) mean of this element's layer outputs
                element_states[element] = (
                    element_states[element] * (count - 1) / count
                    + hidden_states.detach() / count
                )
            else:
                element_states[element] = hidden_states.detach()
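
            # After this update, element_states[element] equals the plain mean
            # (1/k) * sum of the detached outputs of the k layers of this element
            # seen so far; detach() keeps the bookkeeping out of the autograd graph.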

        # Final norm
        hidden_states = self.norm(hidden_states)

        # LM Head
        logits = self.lm_head(hidden_states)

        # Loss
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return {
            "loss": loss,
            "logits": logits,
            "element_states": element_states,
        }

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component."""
        counts = {
            "embedding": sum(p.numel() for p in self.embed_tokens.parameters()),
            "lm_head": sum(p.numel() for p in self.lm_head.parameters()),
            "norm": sum(p.numel() for p in self.norm.parameters()),
        }
        attn_total = 0
        moe_total = 0
        generate_total = 0
        overcome_total = 0
        for layer in self.layers:
            attn_total += sum(p.numel() for p in layer.attention.parameters())
            attn_total += sum(p.numel() for p in layer.input_layernorm.parameters())
            attn_total += sum(p.numel() for p in layer.post_attention_layernorm.parameters())
            moe_total += sum(p.numel() for p in layer.moe.experts.parameters())
            moe_total += sum(p.numel() for p in layer.moe.shared_expert.parameters())
            moe_total += sum(p.numel() for p in layer.moe.router.parameters())
            if layer.moe.generate_boost is not None:
                generate_total += sum(p.numel() for p in layer.moe.generate_boost.parameters())
            if layer.moe.overcome_gate is not None:
                overcome_total += sum(p.numel() for p in layer.moe.overcome_gate.parameters())
        counts["attention_layers"] = attn_total
        counts["moe_experts"] = moe_total
        counts["oheng_generate"] = generate_total
        counts["oheng_overcome"] = overcome_total
| counts["total"] = sum(counts.values()) | |
| return counts | |

    def get_layer_map(self) -> List[Dict]:
        """Return human-readable layer map for diagnostics."""
        result = []
        for i, layer in enumerate(self.layers):
            result.append({
                "layer": i,
                "type": layer.layer_type,
                "element": layer.element,
                "element_idx": ELEMENTS.index(layer.element),
                "phase": i % 5,
                "attn_class": layer.attention.__class__.__name__,
            })
        return result
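

# Minimal smoke-test sketch (illustrative only). It assumes AetherNetConfig can be
# constructed with default values; the field names used here (vocab_size, num_layers)
# are the ones already referenced above, but their defaults are not guaranteed.
if __name__ == "__main__":
    config = AetherNetConfig()
    model = AetherNetModel(config)

    # Parameter and layer-map diagnostics
    print(model.count_parameters()["total"])
    for entry in model.get_layer_map()[:5]:
        print(entry)

    # Tiny forward pass with teacher forcing (labels = inputs)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print(outputs["logits"].shape, outputs["loss"])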