| """ |
| Fashion-MNIST Trainer with MobiusCollective |
| ============================================ |
| |
| Train a wide collective of MobiusLens towers on Fashion-MNIST. |
| Designed for Colab with TensorBoard logging and HuggingFace upload. |
| |
| License: Apache 2.0 |
| Date: 2025-01-10 |
| Author: AbstractPhil |
| """ |
|
|
| import os |
| import json |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
| from typing import Tuple, Dict, Any, Optional |
| from torchvision import datasets, transforms |
| from torch.utils.data import DataLoader |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm.auto import tqdm |
| from datetime import datetime |
| from pathlib import Path |
| from safetensors.torch import save_file as save_safetensors |
|
|
| |
# Optional HuggingFace authentication via the Colab secrets store.
# Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt still
# propagate; any import/auth failure simply disables uploads.
try:
    from huggingface_hub import HfApi, login
    from google.colab import userdata

    token = userdata.get('HF_TOKEN')
    os.environ['HF_TOKEN'] = token
    login(token=token)
    print("Logged in to HuggingFace via Colab")
    HF_AVAILABLE = True
except Exception:
    # Not running in Colab, huggingface_hub missing, or no HF_TOKEN secret.
    HF_AVAILABLE = False
    print("HuggingFace upload disabled (not in Colab or no token)")
|
|
# Select the GPU when one is visible; all modules/tensors below move here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")



# Enable TF32 kernels for matmul and cuDNN convolutions (Ampere+ GPUs):
# significant speedup at a precision that is standard for training.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
|
|
|
|
| |
| |
| |
|
|
| from geofractal.router.wide_router import WideRouter |
| from geofractal.router.base_tower import BaseTower |
| from geofractal.router.components.torch_component import TorchComponent |
| from geofractal.router.components.lens_component import MobiusLens, TriWaveLens |
| from geofractal.router.components.fusion_component import AdaptiveFusion |
|
|
|
|
| |
| |
| |
|
|
class ConvLensBlock(TorchComponent):
    """Depthwise-separable conv with MobiusLens activation.

    Pipeline: 3x3 depthwise conv -> 1x1 pointwise conv -> BatchNorm,
    then a lens activation applied in channel-last layout, finally a
    learned residual blend between the input and the transformed features.

    Args:
        name: Component name; also prefixes the lens sub-component's name.
        channels: Channel count (preserved: input and output both `channels`).
        layer_idx: This block's global position on the scale continuum.
        total_layers: Total blocks across all towers.
        scale_range: (min, max) scale span forwarded to the lens.
        use_mobius: MobiusLens when True, TriWaveLens otherwise.
    """

    def __init__(
        self,
        name: str,
        channels: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (0.5, 2.5),
        use_mobius: bool = True,
    ):
        super().__init__(name)

        # Depthwise-separable convolution (no bias: BatchNorm follows).
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )

        if use_mobius:
            self.lens = MobiusLens(f'{name}_lens', channels, layer_idx, total_layers, scale_range)
        else:
            self.lens = TriWaveLens(f'{name}_lens', channels, layer_idx, total_layers, scale_range)

        # Learnable residual gate; init 0.9 -> sigmoid ~ 0.71, so the block
        # starts close to the identity path.
        self.residual_weight = nn.Parameter(torch.tensor(0.9))

    def forward(self, x: Tensor) -> Tensor:
        """(B, C, H, W) -> (B, C, H, W): residual blend of input and lensed conv."""
        identity = x
        h = self.conv(x)
        # Lens components operate channel-last: (B, H, W, C).
        h = h.permute(0, 2, 3, 1)
        h = self.lens(h)
        h = h.permute(0, 3, 1, 2)
        rw = torch.sigmoid(self.residual_weight)
        return rw * identity + (1 - rw) * h
|
|
|
|
| |
| |
| |
|
|
class LensTower(BaseTower):
    """Shallow tower covering a segment of the scale continuum.

    Tower `tower_idx` owns `depth` consecutive global layer indices, so
    the whole collective of `num_towers` towers spans
    `num_towers * depth` positions on the continuum.
    """

    def __init__(
        self,
        name: str,
        channels: int,
        depth: int,
        tower_idx: int,
        num_towers: int,
        scale_range: Tuple[float, float] = (0.5, 2.5),
        use_mobius: bool = True,
    ):
        super().__init__(name, strict=False)

        self.tower_idx = tower_idx
        self.channels = channels

        total_layers = num_towers * depth
        first_global = tower_idx * depth

        # Stack `depth` conv-lens blocks, each tagged with its global index.
        for local_idx in range(depth):
            self.append(ConvLensBlock(
                f'{name}_block_{local_idx}',
                channels,
                layer_idx=first_global + local_idx,
                total_layers=total_layers,
                scale_range=scale_range,
                use_mobius=use_mobius,
            ))

        self.attach('norm', nn.BatchNorm2d(channels))

    def forward(self, x: Tensor) -> Tensor:
        """Run the input through every block, then the tower's batch norm."""
        out = x
        for block in self.stages:
            out = block(out)
        return self['norm'](out)
|
|
|
|
| |
| |
| |
|
|
class VisionAdaptiveFusion(TorchComponent):
    """
    Wraps AdaptiveFusion for vision tensors (B, C, H, W).

    Permutes to channel-last, fuses, permutes back, then applies a
    1x1 conv + batch-norm projection.
    """

    def __init__(self, name: str, num_towers: int, channels: int):
        super().__init__(name)

        self.num_towers = num_towers
        self.fusion = AdaptiveFusion(
            f'{name}_adaptive',
            num_inputs=num_towers,
            in_features=channels,
        )

        # Post-fusion projection (no bias: BatchNorm follows).
        self.proj = nn.Sequential(
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, *opinions: Tensor) -> Tensor:
        """
        Args:
            *opinions: N tensors of shape (B, C, H, W)
        Returns:
            Fused tensor of shape (B, C, H, W)
        """
        # AdaptiveFusion expects channel-last inputs: (B, H, W, C).
        fused = self.fusion(*(tower_out.permute(0, 2, 3, 1) for tower_out in opinions))
        # Back to channel-first before the conv projection.
        return self.proj(fused.permute(0, 3, 1, 2))
|
|
|
|
| |
| |
| |
|
|
class MobiusCollective(WideRouter):
    """
    Wide collective with MobiusLens towers.

    Architecture:
        - Light stem (configurable stride)
        - Multiple shallow towers in parallel (scale continuum)
        - Adaptive fusion + classification head
    """

    def __init__(
        self,
        name: str = 'mobius_collective',
        in_channels: int = 1,
        channels: int = 64,
        num_towers: int = 4,
        depth_per_tower: int = 2,
        scale_range: Tuple[float, float] = (0.5, 2.5),
        use_mobius: bool = True,
        num_classes: int = 10,
        stem_stride: int = 2,
    ):
        super().__init__(name, auto_discover=True)

        # Record every hyperparameter so get_config() can round-trip the model.
        self.in_channels = in_channels
        self.channels = channels
        self.num_towers = num_towers
        self.depth_per_tower = depth_per_tower
        self.scale_range = scale_range
        self.use_mobius = use_mobius
        self.num_classes = num_classes
        self.stem_stride = stem_stride

        # Light stem: one strided conv downsamples the input.
        self.attach('stem', nn.Sequential(
            nn.Conv2d(in_channels, channels, 3, stride=stem_stride, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        ))

        # Parallel towers, each covering its slice of the scale continuum.
        for idx in range(num_towers):
            self.attach(f'tower_{idx}', LensTower(
                f'tower_{idx}',
                channels=channels,
                depth=depth_per_tower,
                tower_idx=idx,
                num_towers=num_towers,
                scale_range=scale_range,
                use_mobius=use_mobius,
            ))

        # Discover towers before attaching the non-tower components below.
        self.discover_towers()

        self.attach('fusion', VisionAdaptiveFusion('fusion', num_towers, channels))

        # Classification head: global average pool + linear.
        self.attach('pool', nn.AdaptiveAvgPool2d(1))
        self.attach('head', nn.Linear(channels, num_classes))

    def forward(self, x: Tensor) -> Tensor:
        features = self['stem'](x)

        # Collect per-tower opinions in a stable tower_0..tower_{N-1} order.
        opinions = self.wide_forward(features)
        ordered = [opinions[f'tower_{idx}'] for idx in range(self.num_towers)]

        fused = self['fusion'](*ordered)
        pooled = self['pool'](fused).flatten(1)

        return self['head'](pooled)

    def get_config(self) -> Dict[str, Any]:
        """Constructor kwargs (minus `name`) needed to rebuild this model."""
        return {
            key: getattr(self, key)
            for key in (
                'in_channels', 'channels', 'num_towers', 'depth_per_tower',
                'scale_range', 'use_mobius', 'num_classes', 'stem_stride',
            )
        }

    def get_all_lens_stats(self) -> Dict[str, Dict[str, float]]:
        """Return stats from all lenses for logging."""
        stats: Dict[str, Dict[str, float]] = {}
        for tower_name in self.tower_names:
            for idx, block in enumerate(self[tower_name].stages):
                stats[f"{tower_name}_block_{idx}"] = block.lens.get_lens_stats()
        return stats
|
|
|
|
| |
| |
| |
|
|
# Model-size presets: constructor kwargs for MobiusCollective. The trainer
# supplies in_channels / num_classes / stem_stride separately.
PRESETS = {
    # Smallest Mobius variant: fewest channels and towers, narrowest range.
    'fashion_mobius_tiny': {
        'channels': 32,
        'num_towers': 3,
        'depth_per_tower': 2,
        'scale_range': (0.5, 2.0),
        'use_mobius': True,
    },
    # Default small Mobius configuration.
    'fashion_mobius_small': {
        'channels': 64,
        'num_towers': 4,
        'depth_per_tower': 2,
        'scale_range': (0.5, 2.5),
        'use_mobius': True,
    },
    # Larger base configuration: wider channels, deeper towers, wider range.
    'fashion_mobius_base': {
        'channels': 96,
        'num_towers': 4,
        'depth_per_tower': 3,
        'scale_range': (0.25, 2.75),
        'use_mobius': True,
    },
    # Same shape as fashion_mobius_small but with TriWaveLens instead.
    'fashion_tri_small': {
        'channels': 64,
        'num_towers': 4,
        'depth_per_tower': 2,
        'scale_range': (0.5, 2.5),
        'use_mobius': False,
    },
}
|
|
|
|
| |
| |
| |
|
|
def get_fashion_mnist_loaders(
    data_dir: str = './data',
    batch_size: int = 128,
    val_batch_size: int = 256,
    num_workers: int = 4,
):
    """Get Fashion-MNIST train/val loaders with augmentation.

    Args:
        data_dir: Root directory where the dataset is downloaded/cached.
        batch_size: Training batch size.
        val_batch_size: Validation batch size (previously hard-coded to 256;
            can be larger since no gradients are kept).
        num_workers: Worker processes for the training loader; the validation
            loader uses half as many (previously hard-coded to 4 and 2).

    Returns:
        (train_loader, val_loader) tuple of DataLoaders.
    """
    # Fashion-MNIST per-channel mean/std, shared by both transforms.
    normalize = transforms.Normalize((0.2860,), (0.3530,))

    train_transform = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.FashionMNIST(
        data_dir, train=True, download=True, transform=train_transform
    )
    val_dataset = datasets.FashionMNIST(
        data_dir, train=False, download=True, transform=val_transform
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True,
        # persistent_workers requires num_workers > 0.
        persistent_workers=num_workers > 0,
    )
    val_workers = max(1, num_workers // 2) if num_workers > 0 else 0
    val_loader = DataLoader(
        val_dataset, batch_size=val_batch_size, shuffle=False,
        num_workers=val_workers, pin_memory=True,
        persistent_workers=val_workers > 0,
    )

    return train_loader, val_loader
|
|
|
|
| |
| |
| |
|
|
class CheckpointManager:
    """Handles saving, logging, and optional HF upload.

    Creates a timestamped run directory:
        {output_dir}/{experiment_name}/{timestamp}/
            config.json       (model + training config)
            checkpoints/      (best + periodic safetensors, best .pt)
            tensorboard/      (SummaryWriter event files)
    """

    def __init__(
        self,
        output_dir: str,
        experiment_name: str,
        hf_repo: Optional[str] = None,
        save_every: int = 10,
        upload_every: int = 20,
    ):
        """
        Args:
            output_dir: Root directory for all runs.
            experiment_name: Run sub-directory name (typically the preset).
            hf_repo: HuggingFace repo id for uploads; None disables upload.
            save_every: Write a periodic epoch checkpoint every N epochs.
            upload_every: Upload to HF every N epochs (when enabled).
        """
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.experiment_name = experiment_name
        self.hf_repo = hf_repo
        self.save_every = save_every
        self.upload_every = upload_every

        self.run_dir = Path(output_dir) / experiment_name / self.timestamp
        self.ckpt_dir = self.run_dir / "checkpoints"
        self.tb_dir = self.run_dir / "tensorboard"

        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.tb_dir.mkdir(parents=True, exist_ok=True)

        self.writer = SummaryWriter(log_dir=str(self.tb_dir))
        # Upload only when both the HF client logged in and a repo was given.
        self.hf_api = HfApi() if HF_AVAILABLE and hf_repo else None

        # Best-so-far validation accuracy tracking.
        self.best_acc = 0.0
        self.best_epoch = 0

        print(f"Checkpoints: {self.run_dir}")

    @staticmethod
    def _unwrap(model: nn.Module) -> nn.Module:
        """Return the original module beneath a torch.compile wrapper, if any."""
        return model._orig_mod if hasattr(model, '_orig_mod') else model

    def save_config(self, model_config: Dict, train_config: Dict):
        """Write model + training config (and run timestamp) to config.json."""
        config = {
            'model': model_config,
            'training': train_config,
            'timestamp': self.timestamp,
        }
        with open(self.run_dir / "config.json", 'w') as f:
            json.dump(config, f, indent=2)

    def log_scalars(self, epoch: int, scalars: Dict[str, float], prefix: str = ""):
        """Log a dict of scalars to TensorBoard, optionally under `prefix`/."""
        for name, value in scalars.items():
            tag = f"{prefix}/{name}" if prefix else name
            self.writer.add_scalar(tag, value, epoch)

    def log_lens_stats(self, epoch: int, model: nn.Module):
        """Log per-lens numeric stats under lens/{block}/{stat}."""
        stats = self._unwrap(model).get_all_lens_stats()
        for block_name, block_stats in stats.items():
            for stat_name, value in block_stats.items():
                if isinstance(value, (int, float)):
                    self.writer.add_scalar(f"lens/{block_name}/{stat_name}", value, epoch)

    def save_checkpoint(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler,
        epoch: int,
        train_acc: float,
        val_acc: float,
        train_loss: float,
    ):
        """Save best checkpoints and periodic snapshots.

        The best model (by validation accuracy) is written both as
        safetensors (weights only) and as a .pt with full training state.
        Every `save_every` epochs a weights-only snapshot is also saved.
        """
        raw = self._unwrap(model)
        is_best = val_acc > self.best_acc

        if is_best:
            self.best_acc = val_acc
            self.best_epoch = epoch

            save_safetensors(raw.state_dict(), str(self.ckpt_dir / "best_model.safetensors"))
            torch.save({
                'epoch': epoch,
                'model_state_dict': raw.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_acc': self.best_acc,
                'train_acc': train_acc,
                'val_acc': val_acc,
                # Fix: train_loss was accepted but never persisted.
                'train_loss': train_loss,
            }, self.ckpt_dir / "best_model.pt")

        # Periodic snapshot regardless of best status.
        if epoch % self.save_every == 0:
            save_safetensors(raw.state_dict(), str(self.ckpt_dir / f"epoch_{epoch:04d}.safetensors"))

    def upload(self, epoch: int, force: bool = False):
        """Best-effort upload of config + best weights to the HF repo."""
        if not self.hf_api or not self.hf_repo:
            return
        if not force and epoch % self.upload_every != 0:
            return

        try:
            hf_path = f"fashion_mnist/{self.experiment_name}/{self.timestamp}"

            for f in [self.run_dir / "config.json", self.ckpt_dir / "best_model.safetensors"]:
                if f.exists():
                    self.hf_api.upload_file(
                        path_or_fileobj=str(f),
                        path_in_repo=f"{hf_path}/{f.name}",
                        repo_id=self.hf_repo,
                        repo_type="model",
                    )
            print(f"Uploaded to {self.hf_repo}/{hf_path}")
        except Exception as e:
            # Best-effort: a failed upload must not abort training.
            print(f"Upload failed: {e}")

    def close(self):
        """Flush and close the TensorBoard writer."""
        self.writer.close()
|
|
|
|
| |
| |
| |
|
|
def _train_one_epoch(model, loader, optimizer, epoch: int) -> Tuple[float, float]:
    """Run one optimization pass over `loader`; return (avg_loss, accuracy)."""
    model.train()
    running_loss, correct, seen = 0.0, 0, 0

    pbar = tqdm(loader, desc=f"Epoch {epoch:3d}")
    for x, y in pbar:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        # Gradient clipping at max-norm 1.0.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        seen += x.size(0)

        pbar.set_postfix(loss=f"{loss.item():.4f}")

    return running_loss / seen, correct / seen


def _evaluate(model, loader) -> float:
    """Return classification accuracy over `loader` with gradients disabled."""
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            correct += (logits.argmax(1) == y).sum().item()
            seen += x.size(0)
    return correct / seen


def train_fashion_mnist(
    preset: str = 'fashion_mobius_small',
    epochs: int = 50,
    lr: float = 1e-3,
    batch_size: int = 128,
    output_dir: str = './outputs',
    hf_repo: Optional[str] = 'AbstractPhil/mobiusnet-collective',
    use_compile: bool = True,
    save_every: int = 10,
    upload_every: int = 20,
):
    """Train MobiusCollective on Fashion-MNIST.

    Args:
        preset: Key into PRESETS selecting the model configuration.
        epochs: Number of training epochs.
        lr: Initial AdamW learning rate (cosine-annealed to T_max=epochs).
        batch_size: Training batch size.
        output_dir: Root directory for checkpoints/TensorBoard output.
        hf_repo: HuggingFace repo id for periodic uploads; None disables.
        use_compile: Wrap the model with torch.compile when available.
        save_every: Checkpoint cadence (epochs), forwarded to CheckpointManager.
        upload_every: Upload cadence (epochs), forwarded to CheckpointManager.

    Returns:
        (model, best_acc): the trained (possibly compiled) model and the best
        validation accuracy observed.
    """
    config = PRESETS[preset]

    print("=" * 70)
    print(f"FASHION-MNIST - {preset.upper()}")
    print("=" * 70)
    print(f"Channels: {config['channels']}")
    print(f"Towers: {config['num_towers']} x {config['depth_per_tower']} depth")
    print(f"Scale range: {config['scale_range']}")
    print(f"Lens: {'Mobius' if config['use_mobius'] else 'TriWave'}")
    print()

    train_loader, val_loader = get_fashion_mnist_loaders('./data', batch_size)

    model = MobiusCollective(
        name=preset,
        in_channels=1,
        num_classes=10,
        stem_stride=2,
        **config,
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,}")

    ckpt = CheckpointManager(
        output_dir=output_dir,
        experiment_name=preset,
        hf_repo=hf_repo,
        save_every=save_every,
        upload_every=upload_every,
    )

    train_config = {
        'epochs': epochs,
        'lr': lr,
        'batch_size': batch_size,
        'optimizer': 'AdamW',
        'scheduler': 'CosineAnnealingLR',
        'total_params': total_params,
    }
    # Save config before compiling: get_config lives on the raw module.
    ckpt.save_config(model.get_config(), train_config)

    if use_compile and hasattr(torch, 'compile'):
        print("Compiling model...")
        model = torch.compile(model, mode='reduce-overhead')

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_acc = 0.0

    for epoch in range(1, epochs + 1):
        avg_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, epoch)
        scheduler.step()

        val_acc = _evaluate(model, val_loader)
        current_lr = scheduler.get_last_lr()[0]

        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc

        marker = " ★" if is_best else ""
        print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")

        ckpt.log_scalars(epoch, {
            'loss': avg_loss,
            'train_acc': train_acc,
            'val_acc': val_acc,
            'best_acc': best_acc,
            'lr': current_lr,
        }, prefix='train')

        ckpt.log_lens_stats(epoch, model)

        ckpt.save_checkpoint(model, optimizer, scheduler, epoch, train_acc, val_acc, avg_loss)

        ckpt.upload(epoch)

    # Final forced upload + writer flush, regardless of the upload cadence.
    ckpt.upload(epochs, force=True)
    ckpt.close()

    print()
    print("=" * 70)
    print("TRAINING COMPLETE")
    print("=" * 70)
    print(f"Preset: {preset}")
    print(f"Best accuracy: {best_acc:.4f}")
    print(f"Params: {total_params:,}")
    print(f"Checkpoints: {ckpt.run_dir}")
    print("=" * 70)

    return model, best_acc
|
|
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Default run (matches the trainer's own defaults): small Mobius preset,
    # 50 epochs, periodic checkpointing and HuggingFace uploads.
    model, best_acc = train_fashion_mnist(
        preset='fashion_mobius_small',
        epochs=50,
        lr=1e-3,
        batch_size=128,
        output_dir='./outputs',
        hf_repo='AbstractPhil/mobiusnet-collective',
        use_compile=True,
        save_every=10,
        upload_every=20,
    )