""" AETHER-Net: Main Model Adaptive Elemental Transformer-Hybrid Efficient Recurrent Network 25-layer hybrid LLM with 5×5 Latin orthogonal magic square layout and Oheng (五行) MoE routing. """ import torch import torch.nn as nn from typing import Dict, List, Optional, Tuple from config import AetherNetConfig, ELEMENTS, LAYER_TO_ELEMENT, ELEMENT_LAYERS from layers import RMSNorm, build_attention from oheng_moe import OhengMoE class AetherNetBlock(nn.Module): """Single AETHER-Net transformer block. Structure: x → RMSNorm → Attention → residual → RMSNorm → OhengMoE → residual → out """ def __init__(self, config: AetherNetConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx self.layer_type = config.get_layer_type(layer_idx) self.element = config.get_layer_element(layer_idx) # Pre-norm self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps) # Attention (type determined by magic square) self.attention = build_attention(self.layer_type, config) # MoE FFN with Oheng routing self.moe = OhengMoE(config, layer_idx) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, element_states: Optional[Dict[str, torch.Tensor]] = None, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: # Attention block with residual residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.attention( hidden_states, attention_mask=attention_mask, position_ids=position_ids, encoder_hidden_states=encoder_hidden_states, ) hidden_states = residual + hidden_states # MoE FFN block with residual residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.moe(hidden_states, element_states=element_states) hidden_states = residual + hidden_states return hidden_states class AetherNetModel(nn.Module): """AETHER-Net Language Model. Architecture: - Embedding → 25 × AetherNetBlock → RMSNorm → LM Head - Blocks arranged in 5×5 Latin orthogonal magic square - Oheng MoE with 상생 generate and 상극 overcome connections - Element states flow between element groups for structural self-verification """ def __init__(self, config: AetherNetConfig): super().__init__() self.config = config # Token embedding self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) # 25 transformer blocks self.layers = nn.ModuleList([ AetherNetBlock(config, layer_idx=i) for i in range(config.num_layers) ]) # Final norm self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps) # LM Head self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Weight tying if config.tie_word_embeddings: self.lm_head.weight = self.embed_tokens.weight # Initialize self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: B, L = input_ids.shape # Position IDs if position_ids is None: position_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1) # Embed hidden_states = self.embed_tokens(input_ids) # ── Element state tracking for Oheng connections ── # Each element group accumulates its output for 상생/상극 routing element_states: Dict[str, torch.Tensor] = {} element_layer_counts: Dict[str, int] = {e: 0 for e in ELEMENTS} # ── Forward through 25 layers ── for i, layer in enumerate(self.layers): element = LAYER_TO_ELEMENT[i] hidden_states = layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, element_states=element_states, encoder_hidden_states=encoder_hidden_states, ) # Update element state (running average of this element's layer outputs) element_layer_counts[element] += 1 count = element_layer_counts[element] if element in element_states: # Exponential moving average of element's outputs element_states[element] = ( element_states[element] * (count - 1) / count + hidden_states.detach() / count ) else: element_states[element] = hidden_states.detach() # Final norm hidden_states = self.norm(hidden_states) # LM Head logits = self.lm_head(hidden_states) # Loss loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = nn.functional.cross_entropy( shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1), ignore_index=-100, ) return { "loss": loss, "logits": logits, "element_states": element_states, } def count_parameters(self) -> Dict[str, int]: """Count parameters by component.""" counts = { "embedding": sum(p.numel() for p in self.embed_tokens.parameters()), "lm_head": sum(p.numel() for p in self.lm_head.parameters()), "norm": sum(p.numel() for p in self.norm.parameters()), } attn_total = 0 moe_total = 0 generate_total = 0 overcome_total = 0 for layer in self.layers: attn_total += sum(p.numel() for p in layer.attention.parameters()) attn_total += sum(p.numel() for p in layer.input_layernorm.parameters()) attn_total += sum(p.numel() for p in layer.post_attention_layernorm.parameters()) moe_total += sum(p.numel() for p in layer.moe.experts.parameters()) moe_total += sum(p.numel() for p in layer.moe.shared_expert.parameters()) moe_total += sum(p.numel() for p in layer.moe.router.parameters()) if layer.moe.generate_boost is not None: generate_total += sum(p.numel() for p in layer.moe.generate_boost.parameters()) if layer.moe.overcome_gate is not None: overcome_total += sum(p.numel() for p in layer.moe.overcome_gate.parameters()) counts["attention_layers"] = attn_total counts["moe_experts"] = moe_total counts["oheng_generate"] = generate_total counts["oheng_overcome"] = overcome_total counts["total"] = sum(counts.values()) return counts def get_layer_map(self) -> List[Dict]: """Return human-readable layer map for diagnostics.""" result = [] for i, layer in enumerate(self.layers): result.append({ "layer": i, "type": layer.layer_type, "element": layer.element, "element_idx": ELEMENTS.index(layer.element), "phase": i % 5, "attn_class": layer.attention.__class__.__name__, }) return result