# mobiusnet-collective / fashionmnist_trainer.py
# (Converted HF Hub page header to a comment: "AbstractPhil's picture /
# Create fashionmnist_trainer.py / a05e552 verified" was scrape residue,
# not Python source, and would break the module at import time.)
"""
Fashion-MNIST Trainer with MobiusCollective
============================================
Train a wide collective of MobiusLens towers on Fashion-MNIST.
Designed for Colab with TensorBoard logging and HuggingFace upload.
License: Apache 2.0
Date: 2025-01-10
Author: AbstractPhil
"""
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Tuple, Dict, Any, Optional
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from datetime import datetime
from pathlib import Path
from safetensors.torch import save_file as save_safetensors
# HuggingFace login for Colab.
# Best-effort: any failure (not running in Colab, missing packages, or an
# unset HF_TOKEN secret) disables uploading instead of crashing the script.
try:
    from huggingface_hub import HfApi, login
    from google.colab import userdata
    token = userdata.get('HF_TOKEN')
    if not token:
        # userdata.get can return None/empty when the secret is not set;
        # fail early rather than writing None into os.environ.
        raise RuntimeError("HF_TOKEN secret is empty")
    os.environ['HF_TOKEN'] = token
    login(token=token)
    print("Logged in to HuggingFace via Colab")
    HF_AVAILABLE = True
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate during startup.
    HF_AVAILABLE = False
    print("HuggingFace upload disabled (not in Colab or no token)")
# Select GPU when available; all tensors/modules below are moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# Enable TF32 matmul/conv kernels (faster on Ampere+ GPUs, no-op elsewhere)
# and relax float32 matmul precision accordingly.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
# ============================================================================
# IMPORTS FROM GEOFRACTAL
# ============================================================================
from geofractal.router.wide_router import WideRouter
from geofractal.router.base_tower import BaseTower
from geofractal.router.components.torch_component import TorchComponent
from geofractal.router.components.lens_component import MobiusLens, TriWaveLens
from geofractal.router.components.fusion_component import AdaptiveFusion
# ============================================================================
# CONV LENS BLOCK
# ============================================================================
class ConvLensBlock(TorchComponent):
    """Depthwise-separable conv followed by a lens activation, with a
    learned residual gate blending input and transformed output."""
    def __init__(
        self,
        name: str,
        channels: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (0.5, 2.5),
        use_mobius: bool = True,
    ):
        super().__init__(name)
        # Depthwise 3x3 + pointwise 1x1 + BN; the lens supplies the
        # nonlinearity, so no activation is placed here.
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )
        lens_cls = MobiusLens if use_mobius else TriWaveLens
        self.lens = lens_cls(f'{name}_lens', channels, layer_idx, total_layers, scale_range)
        # Gate parameter; sigmoid(0.9) ~ 0.71, so the identity path dominates
        # at initialization.
        self.residual_weight = nn.Parameter(torch.tensor(0.9))
    def forward(self, x: Tensor) -> Tensor:
        # Lens expects channel-last input, so permute around the call.
        out = self.conv(x).permute(0, 2, 3, 1)
        out = self.lens(out).permute(0, 3, 1, 2)
        gate = torch.sigmoid(self.residual_weight)
        return gate * x + (1 - gate) * out
# ============================================================================
# LENS TOWER
# ============================================================================
class LensTower(BaseTower):
    """Shallow tower covering a contiguous segment of the scale continuum.

    Tower `tower_idx` owns global layer indices
    [tower_idx * depth, (tower_idx + 1) * depth) out of num_towers * depth
    total, so the collective's lenses tile the full range.
    """
    def __init__(
        self,
        name: str,
        channels: int,
        depth: int,
        tower_idx: int,
        num_towers: int,
        scale_range: Tuple[float, float] = (0.5, 2.5),
        use_mobius: bool = True,
    ):
        super().__init__(name, strict=False)
        self.tower_idx = tower_idx
        self.channels = channels
        total_layers = num_towers * depth
        offset = tower_idx * depth
        for local_idx in range(depth):
            self.append(ConvLensBlock(
                f'{name}_block_{local_idx}',
                channels,
                layer_idx=offset + local_idx,
                total_layers=total_layers,
                scale_range=scale_range,
                use_mobius=use_mobius,
            ))
        # Final normalization applied after the block stack.
        self.attach('norm', nn.BatchNorm2d(channels))
    def forward(self, x: Tensor) -> Tensor:
        h = x
        for block in self.stages:
            h = block(h)
        return self['norm'](h)
# ============================================================================
# VISION ADAPTIVE FUSION (wraps AdaptiveFusion for BCHW tensors)
# ============================================================================
class VisionAdaptiveFusion(TorchComponent):
    """Adapter letting AdaptiveFusion (channel-last) fuse BCHW tensors.

    Opinions are permuted to (B, H, W, C), fused, permuted back to
    (B, C, H, W), and passed through a 1x1 conv + BN projection.
    """
    def __init__(self, name: str, num_towers: int, channels: int):
        super().__init__(name)
        self.num_towers = num_towers
        self.fusion = AdaptiveFusion(
            f'{name}_adaptive',
            num_inputs=num_towers,
            in_features=channels,
        )
        # 1x1 conv keeps the projection spatial-shape-preserving.
        self.proj = nn.Sequential(
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )
    def forward(self, *opinions: Tensor) -> Tensor:
        """Fuse N opinion tensors of shape (B, C, H, W) into one (B, C, H, W)."""
        channel_last = (t.permute(0, 2, 3, 1) for t in opinions)
        fused = self.fusion(*channel_last)  # (B, H, W, C)
        return self.proj(fused.permute(0, 3, 1, 2))
# ============================================================================
# MOBIUS COLLECTIVE
# ============================================================================
class MobiusCollective(WideRouter):
    """Wide collective of parallel lens towers.

    Pipeline: a light conv stem (configurable stride) feeds several shallow
    towers run in parallel (each covering a segment of the scale continuum),
    whose opinions are adaptively fused and classified by a linear head.
    """
    def __init__(
        self,
        name: str = 'mobius_collective',
        in_channels: int = 1,
        channels: int = 64,
        num_towers: int = 4,
        depth_per_tower: int = 2,
        scale_range: Tuple[float, float] = (0.5, 2.5),
        use_mobius: bool = True,
        num_classes: int = 10,
        stem_stride: int = 2,
    ):
        super().__init__(name, auto_discover=True)
        # Keep the full configuration on the instance for get_config().
        self.in_channels = in_channels
        self.channels = channels
        self.num_towers = num_towers
        self.depth_per_tower = depth_per_tower
        self.scale_range = scale_range
        self.use_mobius = use_mobius
        self.num_classes = num_classes
        self.stem_stride = stem_stride
        # Stem: single strided conv + BN + ReLU.
        stem = nn.Sequential(
            nn.Conv2d(in_channels, channels, 3, stride=stem_stride, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        self.attach('stem', stem)
        # Parallel towers, each owning a slice of the global layer range.
        for idx in range(num_towers):
            self.attach(f'tower_{idx}', LensTower(
                f'tower_{idx}',
                channels=channels,
                depth=depth_per_tower,
                tower_idx=idx,
                num_towers=num_towers,
                scale_range=scale_range,
                use_mobius=use_mobius,
            ))
        # Register towers with the router before attaching non-tower parts.
        self.discover_towers()
        # Fusion (wraps geofractal's AdaptiveFusion for vision tensors) + head.
        self.attach('fusion', VisionAdaptiveFusion('fusion', num_towers, channels))
        self.attach('pool', nn.AdaptiveAvgPool2d(1))
        self.attach('head', nn.Linear(channels, num_classes))
    def forward(self, x: Tensor) -> Tensor:
        features = self['stem'](x)
        # wide_forward runs every tower on the shared stem features.
        opinions = self.wide_forward(features)
        ordered = (opinions[f'tower_{i}'] for i in range(self.num_towers))
        fused = self['fusion'](*ordered)
        pooled = self['pool'](fused).flatten(1)
        return self['head'](pooled)
    def get_config(self) -> Dict[str, Any]:
        """Serializable constructor arguments (minus `name`)."""
        return dict(
            in_channels=self.in_channels,
            channels=self.channels,
            num_towers=self.num_towers,
            depth_per_tower=self.depth_per_tower,
            scale_range=self.scale_range,
            use_mobius=self.use_mobius,
            num_classes=self.num_classes,
            stem_stride=self.stem_stride,
        )
    def get_all_lens_stats(self) -> Dict[str, Dict[str, float]]:
        """Per-lens statistics keyed by '<tower>_block_<i>' for logging."""
        return {
            f"{tower_name}_block_{idx}": block.lens.get_lens_stats()
            for tower_name in self.tower_names
            for idx, block in enumerate(self[tower_name].stages)
        }
# ============================================================================
# PRESETS
# ============================================================================
# Named model configurations, passed as **kwargs to MobiusCollective.
# Naming: fashion_<lens>_<size>; the 'tri' variant swaps MobiusLens for
# TriWaveLens via use_mobius=False.
PRESETS = {
    # Smallest: 3 towers x 2 blocks at 32 channels.
    'fashion_mobius_tiny': {
        'channels': 32,
        'num_towers': 3,
        'depth_per_tower': 2,
        'scale_range': (0.5, 2.0),
        'use_mobius': True,
    },
    # Default: 4 towers x 2 blocks at 64 channels.
    'fashion_mobius_small': {
        'channels': 64,
        'num_towers': 4,
        'depth_per_tower': 2,
        'scale_range': (0.5, 2.5),
        'use_mobius': True,
    },
    # Largest: 4 towers x 3 blocks at 96 channels, widest scale range.
    'fashion_mobius_base': {
        'channels': 96,
        'num_towers': 4,
        'depth_per_tower': 3,
        'scale_range': (0.25, 2.75),
        'use_mobius': True,
    },
    # Same size as 'small' but with TriWaveLens instead of MobiusLens.
    'fashion_tri_small': {
        'channels': 64,
        'num_towers': 4,
        'depth_per_tower': 2,
        'scale_range': (0.5, 2.5),
        'use_mobius': False,
    },
}
# ============================================================================
# DATA
# ============================================================================
def get_fashion_mnist_loaders(
    data_dir: str = './data',
    batch_size: int = 128,
    val_batch_size: int = 256,
    train_workers: int = 4,
    val_workers: int = 2,
):
    """Build Fashion-MNIST train/val DataLoaders.

    The train split gets light augmentation (random crop with padding,
    horizontal flip); both splits are normalized with the Fashion-MNIST
    mean/std (0.2860 / 0.3530).

    Args:
        data_dir: Root directory for the (auto-downloaded) dataset.
        batch_size: Training batch size.
        val_batch_size: Validation batch size (previously hard-coded to 256).
        train_workers: Worker processes for the train loader.
        val_workers: Worker processes for the val loader.

    Returns:
        Tuple of (train_loader, val_loader).
    """
    normalize = transforms.Normalize((0.2860,), (0.3530,))
    train_transform = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = datasets.FashionMNIST(
        data_dir, train=True, download=True, transform=train_transform
    )
    val_dataset = datasets.FashionMNIST(
        data_dir, train=False, download=True, transform=val_transform
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=train_workers, pin_memory=True,
        # persistent_workers is only valid with at least one worker.
        persistent_workers=train_workers > 0,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=val_batch_size, shuffle=False,
        num_workers=val_workers, pin_memory=True,
        persistent_workers=val_workers > 0,
    )
    return train_loader, val_loader
# ============================================================================
# CHECKPOINT MANAGER
# ============================================================================
class CheckpointManager:
    """Handles saving, logging, and optional HF upload.

    Creates a timestamped run directory:
        <output_dir>/<experiment_name>/<timestamp>/{checkpoints, tensorboard}
    and owns the TensorBoard SummaryWriter for the run.
    """
    def __init__(
        self,
        output_dir: str,
        experiment_name: str,
        hf_repo: Optional[str] = None,
        save_every: int = 10,
        upload_every: int = 20,
    ):
        # Timestamp isolates repeated runs of the same experiment.
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.experiment_name = experiment_name
        self.hf_repo = hf_repo
        self.save_every = save_every
        self.upload_every = upload_every
        self.run_dir = Path(output_dir) / experiment_name / self.timestamp
        self.ckpt_dir = self.run_dir / "checkpoints"
        self.tb_dir = self.run_dir / "tensorboard"
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.tb_dir.mkdir(parents=True, exist_ok=True)
        self.writer = SummaryWriter(log_dir=str(self.tb_dir))
        # Uploads require both a successful Colab HF login and a target repo.
        self.hf_api = HfApi() if HF_AVAILABLE and hf_repo else None
        # Best-so-far validation accuracy; drives best_model.* saving.
        self.best_acc = 0.0
        self.best_epoch = 0
        print(f"Checkpoints: {self.run_dir}")
    def save_config(self, model_config: Dict, train_config: Dict):
        """Write model + training hyperparameters to config.json in run_dir."""
        config = {
            'model': model_config,
            'training': train_config,
            'timestamp': self.timestamp,
        }
        with open(self.run_dir / "config.json", 'w') as f:
            json.dump(config, f, indent=2)
    def log_scalars(self, epoch: int, scalars: Dict[str, float], prefix: str = ""):
        """Log a dict of scalar values to TensorBoard, optionally prefixed."""
        for name, value in scalars.items():
            tag = f"{prefix}/{name}" if prefix else name
            self.writer.add_scalar(tag, value, epoch)
    def log_lens_stats(self, epoch: int, model: nn.Module):
        """Log per-lens stats under 'lens/<block>/<stat>'.

        Unwraps torch.compile's wrapper (via _orig_mod) so the model's
        get_all_lens_stats() is called on the original module.
        """
        raw = model._orig_mod if hasattr(model, '_orig_mod') else model
        stats = raw.get_all_lens_stats()
        for block_name, block_stats in stats.items():
            for stat_name, value in block_stats.items():
                # Lens stats may contain non-numeric entries; log numbers only.
                if isinstance(value, (int, float)):
                    self.writer.add_scalar(f"lens/{block_name}/{stat_name}", value, epoch)
    def save_checkpoint(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler,
        epoch: int,
        train_acc: float,
        val_acc: float,
        train_loss: float,
    ):
        """Save best-model files on val_acc improvement, plus periodic saves.

        Writes best_model.safetensors (weights only) and best_model.pt
        (with optimizer/scheduler state for resuming) whenever val_acc beats
        the running best, and an epoch-numbered safetensors file every
        `save_every` epochs.

        NOTE(review): `train_loss` is accepted but currently unused here.
        """
        raw = model._orig_mod if hasattr(model, '_orig_mod') else model
        is_best = val_acc > self.best_acc
        if is_best:
            self.best_acc = val_acc
            self.best_epoch = epoch
            # Save best
            save_safetensors(raw.state_dict(), str(self.ckpt_dir / "best_model.safetensors"))
            torch.save({
                'epoch': epoch,
                'model_state_dict': raw.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_acc': self.best_acc,
                'train_acc': train_acc,
                'val_acc': val_acc,
            }, self.ckpt_dir / "best_model.pt")
        # Periodic save
        if epoch % self.save_every == 0:
            save_safetensors(raw.state_dict(), str(self.ckpt_dir / f"epoch_{epoch:04d}.safetensors"))
    def upload(self, epoch: int, force: bool = False):
        """Upload config + best model to the HF repo every `upload_every` epochs.

        Best-effort: failures are printed, never raised. `force=True` skips
        the epoch-interval check (used for the final upload).
        """
        if not self.hf_api or not self.hf_repo:
            return
        if not force and epoch % self.upload_every != 0:
            return
        try:
            hf_path = f"fashion_mnist/{self.experiment_name}/{self.timestamp}"
            for f in [self.run_dir / "config.json", self.ckpt_dir / "best_model.safetensors"]:
                if f.exists():
                    self.hf_api.upload_file(
                        path_or_fileobj=str(f),
                        path_in_repo=f"{hf_path}/{f.name}",
                        repo_id=self.hf_repo,
                        repo_type="model",
                    )
            print(f"Uploaded to {self.hf_repo}/{hf_path}")
        except Exception as e:
            print(f"Upload failed: {e}")
    def close(self):
        """Close the TensorBoard writer, flushing pending events."""
        self.writer.close()
# ============================================================================
# TRAINING
# ============================================================================
def train_fashion_mnist(
    preset: str = 'fashion_mobius_small',
    epochs: int = 50,
    lr: float = 1e-3,
    batch_size: int = 128,
    output_dir: str = './outputs',
    hf_repo: Optional[str] = 'AbstractPhil/mobiusnet-collective',
    use_compile: bool = True,
    save_every: int = 10,
    upload_every: int = 20,
):
    """Train MobiusCollective on Fashion-MNIST.

    Args:
        preset: Key into PRESETS selecting the model configuration.
        epochs: Number of training epochs.
        lr: AdamW learning rate, cosine-annealed over `epochs`.
        batch_size: Training batch size.
        output_dir: Root directory for checkpoints and TensorBoard logs.
        hf_repo: HuggingFace repo id for uploads, or None to disable.
        use_compile: Wrap the model in torch.compile when available.
        save_every: Checkpoint interval in epochs (see CheckpointManager).
        upload_every: HF upload interval in epochs (see CheckpointManager).

    Returns:
        (model, best_acc): the trained (possibly compiled) model and the
        best validation accuracy observed.
    """
    config = PRESETS[preset]
    print("=" * 70)
    print(f"FASHION-MNIST - {preset.upper()}")
    print("=" * 70)
    print(f"Channels: {config['channels']}")
    print(f"Towers: {config['num_towers']} x {config['depth_per_tower']} depth")
    print(f"Scale range: {config['scale_range']}")
    print(f"Lens: {'Mobius' if config['use_mobius'] else 'TriWave'}")
    print()
    # Data
    train_loader, val_loader = get_fashion_mnist_loaders('./data', batch_size)
    # Model
    model = MobiusCollective(
        name=preset,
        in_channels=1,  # Fashion-MNIST is grayscale
        num_classes=10,
        stem_stride=2,  # 28x28 -> 14x14
        **config,
    ).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,}")
    # Checkpoint manager
    ckpt = CheckpointManager(
        output_dir=output_dir,
        experiment_name=preset,
        hf_repo=hf_repo,
        save_every=save_every,
        upload_every=upload_every,
    )
    # Save config
    train_config = {
        'epochs': epochs,
        'lr': lr,
        'batch_size': batch_size,
        'optimizer': 'AdamW',
        'scheduler': 'CosineAnnealingLR',
        'total_params': total_params,
    }
    ckpt.save_config(model.get_config(), train_config)
    # Compile (done after config save; CheckpointManager unwraps _orig_mod
    # where it needs the original module)
    if use_compile and hasattr(torch, 'compile'):
        print("Compiling model...")
        model = torch.compile(model, mode='reduce-overhead')
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        # Train
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}")
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            # Gradient clipping at max-norm 1.0 to stabilize training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Accumulate sums weighted by batch size so the epoch average is
            # exact even when the final batch is smaller.
            train_loss += loss.item() * x.size(0)
            train_correct += (logits.argmax(1) == y).sum().item()
            train_total += x.size(0)
            pbar.set_postfix(loss=f"{loss.item():.4f}")
        scheduler.step()
        # Validate
        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                val_correct += (logits.argmax(1) == y).sum().item()
                val_total += x.size(0)
        # Metrics
        train_acc = train_correct / train_total
        val_acc = val_correct / val_total
        avg_loss = train_loss / train_total
        current_lr = scheduler.get_last_lr()[0]
        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc
        marker = " ★" if is_best else ""
        print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")
        # Logging
        # NOTE(review): val_acc/best_acc are logged under the 'train/' prefix
        # here; consider splitting into separate 'train'/'val' groups.
        ckpt.log_scalars(epoch, {
            'loss': avg_loss,
            'train_acc': train_acc,
            'val_acc': val_acc,
            'best_acc': best_acc,
            'lr': current_lr,
        }, prefix='train')
        ckpt.log_lens_stats(epoch, model)
        # Save
        ckpt.save_checkpoint(model, optimizer, scheduler, epoch, train_acc, val_acc, avg_loss)
        # Upload
        ckpt.upload(epoch)
    # Final upload
    ckpt.upload(epochs, force=True)
    ckpt.close()
    print()
    print("=" * 70)
    print("TRAINING COMPLETE")
    print("=" * 70)
    print(f"Preset: {preset}")
    print(f"Best accuracy: {best_acc:.4f}")
    print(f"Params: {total_params:,}")
    print(f"Checkpoints: {ckpt.run_dir}")
    print("=" * 70)
    return model, best_acc
# ============================================================================
# MAIN
# ============================================================================
if __name__ == '__main__':
    # Default entry point: train the small Mobius preset for 50 epochs with
    # AdamW + cosine annealing, checkpointing every 10 and uploading every 20.
    model, best_acc = train_fashion_mnist(
        preset='fashion_mobius_small',
        epochs=50,
        lr=1e-3,
        batch_size=128,
        output_dir='./outputs',
        hf_repo='AbstractPhil/mobiusnet-collective',  # Set to None to disable upload
        use_compile=True,
        save_every=10,
        upload_every=20,
    )