| | """ |
| | BEATRIX FLOW-MATCHING - CIFAR-10 (T5 Text Encoder) |
| | =================================================== |
| | |
| | SD 1.5 VAE + Flan-T5-XL text encoder (google/flan-t5-xl) |
| | Dual tower collectives: vision towers + text towers |
| | |
| | Text prompts for CIFAR-10 classes: |
| | "a photo of an airplane" |
| | "a photo of an automobile" |
| | etc. |
| | |
| | Requirements: |
| | pip install transformers diffusers torchvision tqdm |
| | pip install git+https://github.com/AbstractEyes/geofractal |
| | |
| | Performance note: training is currently slow; optimization is still pending. |
| | |
| | apache 2.0 license |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import math |
| | from dataclasses import dataclass |
| | from typing import Dict, Tuple, Optional, List |
| | from pathlib import Path |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch import Tensor |
| | from torch.utils.data import DataLoader, Dataset |
| | from torchvision import datasets, transforms |
| | from torchvision.utils import make_grid, save_image |
| | from huggingface_hub import HfApi, upload_file, create_repo |
| | import json |
| | from tqdm import tqdm |
| |
|
| | |
| | |
| | |
| |
|
| | from geofractal.router.wide_router import WideRouter |
| | from geofractal.router.prefab.agatha.beatrix_tension_oscillator import ( |
| | BeatrixOscillator, |
| | ScheduleType, |
| | ) |
| | from geofractal.router.prefab.geometric_tower_builder import ( |
| | TowerConfig, |
| | FusionType, |
| | ConfigurableCollective, |
| | build_tower_collective, |
| | preset_pos_neg_pairs, |
| | ) |
| | from geofractal.router.prefab.geometric_conv_tower_builder import ( |
| | ConvTowerConfig, |
| | ConvTowerCollective, |
| | build_conv_collective, |
| | preset_conv_pos_neg, |
| | ) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Text prompt for each CIFAR-10 class. The list position is the class label:
# CachedCIFAR10T5 encodes these prompts in order and looks them up by label.
CIFAR10_PROMPTS = [
    "a photo of an airplane",    # 0
    "a photo of an automobile",  # 1
    "a photo of a bird",         # 2
    "a photo of a cat",          # 3
    "a photo of a deer",         # 4
    "a photo of a dog",          # 5
    "a photo of a frog",         # 6
    "a photo of a horse",        # 7
    "a photo of a ship",         # 8
    "a photo of a truck",        # 9
]
| |
|
| |
|
| | |
| | |
| | |
| |
|
class SD15VAE(nn.Module):
    """Wrapper around the Stable Diffusion 1.5 KL autoencoder (``diffusers``).

    ``encode`` samples from the posterior and multiplies by ``scale_factor``.
    ``decode`` divides by ``scale_factor`` before decoding. Both run under
    ``torch.no_grad()``.

    Args:
        freeze: when True, put the VAE in eval mode and disable gradients
            on all of its parameters.
    """

    def __init__(self, freeze: bool = True):
        super().__init__()
        # Imported lazily so the module can be imported without diffusers.
        from diffusers import AutoencoderKL

        self.vae = AutoencoderKL.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="vae",
            torch_dtype=torch.float32,
        )

        if freeze:
            self.vae.eval()
            self.vae.requires_grad_(False)

        # Latent scaling constant applied on encode and removed on decode.
        self.scale_factor = 0.18215

    @torch.no_grad()
    def encode(self, x: Tensor) -> Tensor:
        """Images -> scaled latents (one posterior sample per image)."""
        posterior = self.vae.encode(x).latent_dist
        return posterior.sample() * self.scale_factor

    @torch.no_grad()
    def decode(self, z: Tensor) -> Tensor:
        """Scaled latents -> decoded images."""
        unscaled = z / self.scale_factor
        return self.vae.decode(unscaled).sample
| |
|
| |
|
| | |
| | |
| | |
| |
|
class T5TextEncoder(nn.Module):
    """Flan-T5 encoder with an optional bottleneck projection.

    ``encode_raw`` returns the encoder's hidden states (width ``raw_dim`` =
    the model's ``d_model``) and their attention-masked mean. ``forward``
    first projects the hidden states through ``bottleneck`` to
    ``bottleneck_dim``, then pools.

    Args:
        model_name: Hugging Face model id for tokenizer and encoder.
        freeze: when True, put the encoder in eval mode and disable its gradients.
        max_length: every prompt is padded/truncated to this many tokens.
        bottleneck_dim: output width of ``bottleneck`` (and of ``forward``).
    """

    def __init__(
        self,
        model_name: str = "google/flan-t5-xl",
        freeze: bool = True,
        max_length: int = 77,
        bottleneck_dim: int = 256,
    ):
        super().__init__()
        # Imported lazily so the module can be imported without transformers.
        from transformers import T5EncoderModel, T5Tokenizer

        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.encoder = T5EncoderModel.from_pretrained(model_name)
        self.max_length = max_length
        self.raw_dim = self.encoder.config.d_model
        self.output_dim = bottleneck_dim

        # NOTE: forward() runs under torch.no_grad(), so this projection gets
        # no gradients when used through this module.
        self.bottleneck = nn.Sequential(
            nn.Linear(self.raw_dim, bottleneck_dim),
            nn.GELU(),
            nn.Linear(bottleneck_dim, bottleneck_dim),
        )

        if freeze:
            self.encoder.eval()
            for p in self.encoder.parameters():
                p.requires_grad = False

    @staticmethod
    def _masked_mean(sequence: Tensor, attention_mask: Tensor) -> Tensor:
        """Average ``sequence`` [B, L, D] over positions where ``attention_mask`` [B, L] is non-zero.

        The token count is clamped to >= 1, so a row with no valid tokens
        yields zeros instead of NaN.
        """
        mask = attention_mask.unsqueeze(-1).float()
        return (sequence * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    @torch.no_grad()
    def _encode_tokens(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """Tokenize ``texts`` and run the encoder.

        Returns ``(last_hidden_state [B, max_length, raw_dim], attention_mask [B, max_length])``,
        both on ``device``.
        """
        tokens = self.tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(device)
        attention_mask = tokens.attention_mask.to(device)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state, attention_mask

    @torch.no_grad()
    def forward(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """
        Encode text prompts with bottleneck.

        Returns:
            sequence: [B, L, bottleneck_dim] - compressed sequence embeddings
            pooled: [B, bottleneck_dim] - mask-weighted mean of ``sequence``
        """
        hidden, attention_mask = self._encode_tokens(texts, device)
        sequence = self.bottleneck(hidden)
        return sequence, self._masked_mean(sequence, attention_mask)

    @torch.no_grad()
    def encode_raw(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """
        Encode text prompts WITHOUT bottleneck (for caching raw embeddings).

        Returns:
            sequence: [B, L, raw_dim] - raw T5 embeddings
            pooled: [B, raw_dim] - mask-weighted mean of ``sequence``
        """
        sequence, attention_mask = self._encode_tokens(texts, device)
        return sequence, self._masked_mean(sequence, attention_mask)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class CachedCIFAR10T5(Dataset):
    """
    Pre-cached CIFAR-10 with VAE latents.
    T5 embeddings are computed per-class (not per-image).

    The first run encodes the whole split and writes one ``.pt`` file under
    ``cache_dir``; later runs load that file.

    The text embeddings are raw (un-bottlenecked) T5 outputs for
    ``CIFAR10_PROMPTS``. They are stored once per class and looked up by
    label in ``__getitem__``.

    Items are ``(latent, text_sequence, text_pooled, label)``.
    """

    T5_MODEL = "google/flan-t5-xl"

    def __init__(
        self,
        train: bool = True,
        image_size: int = 256,
        cache_dir: str = "./cache",
        device: str = "cuda",
    ):
        self.train = train

        t5_suffix = self.T5_MODEL.replace("/", "_")
        self.cache_path = Path(cache_dir) / f"cifar10_{t5_suffix}_{'train' if train else 'val'}_{image_size}.pt"

        if self.cache_path.exists():
            print(f"Loading cache: {self.cache_path}")
            # Our own cache file; it is trusted local data.
            cache = torch.load(self.cache_path, weights_only=False)
            self.latents = cache['latents']
            self.labels = cache['labels']
            self.text_sequence = cache['text_sequence']
            self.text_pooled = cache['text_pooled']
            # Older caches may lack 'text_dim'; fall back to the embedding width.
            self.text_dim = cache.get('text_dim', self.text_pooled.shape[-1])
        else:
            print(f"Building cache for {'train' if train else 'val'} set...")
            self._build_cache(image_size, device)

    def _build_cache(self, image_size: int, device: str):
        """Encode prompts (T5) and images (VAE), then write the cache file atomically."""
        # Text first, and release T5 before the image loop, so the large
        # text encoder and the VAE are never resident at the same time.
        print(f"  Loading T5 ({self.T5_MODEL})...")
        t5 = T5TextEncoder(model_name=self.T5_MODEL, freeze=True).to(device)

        print(f"  Encoding text prompts (T5 raw_dim={t5.raw_dim})...")
        text_seq, text_pool = t5.encode_raw(CIFAR10_PROMPTS, device)
        self.text_sequence = text_seq.cpu()
        self.text_pooled = text_pool.cpu()
        self.text_dim = t5.raw_dim

        del t5, text_seq, text_pool
        torch.cuda.empty_cache()

        print("  Loading VAE...")
        vae = SD15VAE(freeze=True).to(device)

        # Resize to the VAE working resolution and map pixels to [-1, 1].
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        dataset = datasets.CIFAR10('./data', train=self.train, download=True, transform=transform)
        # Pinned memory only helps host->GPU copies.
        use_pinned = str(device).startswith("cuda")
        loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=use_pinned)

        all_latents, all_labels = [], []

        print("  Encoding images...")
        with torch.no_grad():
            for images, labels in tqdm(loader, desc="  Caching", leave=False):
                images = images.to(device)
                all_latents.append(vae.encode(images).cpu())
                all_labels.append(labels)

        self.latents = torch.cat(all_latents, dim=0)
        self.labels = torch.cat(all_labels, dim=0)

        del vae
        torch.cuda.empty_cache()

        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.cache_path.with_name(self.cache_path.name + ".tmp")
        torch.save({
            'latents': self.latents,
            'labels': self.labels,
            'text_sequence': self.text_sequence,
            'text_pooled': self.text_pooled,
            'text_dim': self.text_dim,
        }, tmp_path)
        # Publish only a complete file: an interrupted save must not leave a
        # truncated cache that __init__ would later try to load.
        tmp_path.replace(self.cache_path)
        print(f"  Saved: {self.cache_path}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        # Text embeddings are per class, indexed by the label.
        return (
            self.latents[idx],
            self.text_sequence[label],
            self.text_pooled[label],
            label,
        )
| |
|
| |
|
| | |
| | |
| | |
| |
|
class SinusoidalEmbed(nn.Module):
    """Fixed sinusoidal features for scalar inputs (e.g. flow time ``t``).

    Maps ``t`` of shape ``[...]`` to ``[..., dim]``. The first ``dim // 2``
    features are cosines, the next ``dim // 2`` are sines. Frequencies are
    ``exp(-ln(10000) * k / (dim // 2))`` for ``k = 0 .. dim//2 - 1``.

    When ``dim`` is odd, a zero column is appended so the output always has
    exactly ``dim`` features (previously it silently had ``dim - 1``).
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: Tensor) -> Tensor:
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
        args = t.unsqueeze(-1) * freqs
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:
            # Pad the last dimension by one so downstream Linear(dim, ...) layers match.
            emb = F.pad(emb, (0, 1))
        return emb
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class FlowConfig:
    """Hyper-parameters for the Beatrix flow model and its training run.

    Field order matters (the dataclass constructor is positional); the
    comment groups below are purely organisational.
    """

    # --- data / latent space ---
    image_size: int = 256
    num_classes: int = 10
    latent_channels: int = 4
    latent_size: int = 32

    # --- text conditioning ---
    text_raw_dim: int = 2048      # Trainer overwrites this from the dataset cache
    text_seq_len: int = 77
    bottleneck_dim: int = 256

    # --- geometric (transformer) towers ---
    tower_dim: int = 256
    tower_depth: int = 2
    num_heads: int = 8
    geometric_types: Tuple[str, ...] = ('cantor', 'beatrix', 'helix', 'simplex')

    # --- conv towers ---
    conv_types: Tuple[str, ...] = ('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite')
    conv_spatial_size: int = 8

    # --- oscillator ---
    manifold_dim: int = 1024
    num_tower_pairs: int = 16     # BeatrixFlowT5 in this file derives its pair count from the tower configs
    osc_steps: int = 50
    fingerprint_dim: int = 64

    # --- flow sampling ---
    num_flow_steps: int = 50
    sigma_min: float = 0.001

    # --- optimisation ---
    batch_size: int = 64
    lr: float = 1e-4
    weight_decay: float = 0.01
    num_epochs: int = 100

    # --- I/O ---
    cache_dir: str = "./cache"
    device: str = "cuda"
    output_dir: str = "./beatrix_cifar_t5"

    @property
    def latent_flat_dim(self) -> int:
        """Size of one flattened latent: channels × size × size (4 × 32 × 32 = 4096 by default)."""
        side = self.latent_size
        return self.latent_channels * side * side
| |
|
| |
|
| | |
| | |
| | |
| |
|
class BeatrixFlowT5(WideRouter):
    """
    Flow-matching velocity model with dual tower collectives per modality.

    Vision side (input: 4x4 latent patches projected to ``tower_dim``):
      - ``vision_geo``: geometric transformer towers for ``cfg.geometric_types`` (pos/neg)
      - ``vision_conv``: conv towers for ``cfg.conv_types`` (pos/neg)

    Text side (input: bottlenecked T5 sequence), mirrored:
      - ``text_geo`` / ``text_conv``: same tower types as the vision side

    All tower opinions feed the oscillator's force generator. The velocity is
    predicted in a ``manifold_dim`` projection of the flattened latent. It is
    a learned sigmoid blend of that tower force and a spring force pulling
    toward a text+time reference point, and is projected back to latent size.

    Noise convention: ``z_t = (1 - t) * z_0 + t * eps``, target ``v = eps - z_0``.
    ``t = 0`` is data and ``t = 1`` is noise.
    """

    def __init__(self, cfg: FlowConfig):
        super().__init__(name='beatrix_flow_t5', strict=False, auto_discover=False)
        self.objects['cfg'] = cfg

        # NOTE: module/parameter creation order below is kept stable so that
        # saved optimizer states (which are order-indexed) keep matching.

        # --- text bottlenecks: raw T5 width -> bottleneck_dim ---
        self.attach('text_bottleneck_seq', nn.Sequential(
            nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
            nn.GELU(),
            nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
        ))
        self.attach('text_bottleneck_pool', nn.Sequential(
            nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
            nn.GELU(),
            nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
        ))

        # --- vision geometric towers ---
        vision_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))
        vision_geo_collective = build_tower_collective(
            configs=vision_geo_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            num_heads=cfg.num_heads,
            ffn_mult=4.0,
            dropout=0.1,
            fingerprint_dim=cfg.fingerprint_dim,
            fusion_type='adaptive',
            name='vision_geo',
        )
        self.attach('vision_geo', vision_geo_collective)

        # --- vision conv towers ---
        vision_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))
        vision_conv_collective = build_conv_collective(
            configs=vision_conv_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            fingerprint_dim=cfg.fingerprint_dim,
            spatial_size=cfg.conv_spatial_size,
            name='vision_conv',
        )
        self.attach('vision_conv', vision_conv_collective)

        # --- text geometric towers ---
        text_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))
        text_geo_collective = build_tower_collective(
            configs=text_geo_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            num_heads=cfg.num_heads,
            ffn_mult=4.0,
            dropout=0.1,
            fingerprint_dim=cfg.fingerprint_dim,
            fusion_type='adaptive',
            name='text_geo',
        )
        self.attach('text_geo', text_geo_collective)

        # --- text conv towers ---
        text_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))
        text_conv_collective = build_conv_collective(
            configs=text_conv_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            fingerprint_dim=cfg.fingerprint_dim,
            spatial_size=cfg.conv_spatial_size,
            name='text_conv',
        )
        self.attach('text_conv', text_conv_collective)

        # --- latent patch embedding ---
        patch_size = 4
        num_patches = (cfg.latent_size // patch_size) ** 2
        patch_dim = cfg.latent_channels * patch_size * patch_size

        self.attach('patch_proj', nn.Linear(patch_dim, cfg.tower_dim))
        self.patch_pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.tower_dim) * 0.02)
        self.objects['patch_size'] = patch_size
        self.objects['num_patches'] = num_patches

        # --- oscillator: one pair slot per (pos, neg) couple across all four collectives ---
        num_geo_towers = len(vision_geo_configs)
        num_conv_towers = len(vision_conv_configs)
        total_towers = (num_geo_towers + num_conv_towers) * 2

        oscillator = BeatrixOscillator(
            name='oscillator',
            manifold_dim=cfg.manifold_dim,
            tower_dim=cfg.tower_dim,
            num_tower_pairs=total_towers // 2,
            num_theta_probes=4,
            fingerprint_dim=cfg.fingerprint_dim,
            kappa_schedule=ScheduleType.TESLA_369,
            use_intrinsic_tension=True,
        )
        self.attach('oscillator', oscillator)

        # --- conditioning: time embedding and reference-point projections ---
        time_embed = nn.Sequential(
            SinusoidalEmbed(256),
            nn.Linear(256, cfg.tower_dim),
            nn.GELU(),
            nn.Linear(cfg.tower_dim, cfg.tower_dim),
        )
        self.attach('time_embed', time_embed)

        self.attach('text_to_ref', nn.Sequential(
            nn.Linear(cfg.bottleneck_dim, cfg.manifold_dim),
            nn.GELU(),
            nn.Linear(cfg.manifold_dim, cfg.manifold_dim),
        ))

        self.attach('time_to_ref', nn.Linear(cfg.tower_dim, cfg.manifold_dim))

        # --- flattened latent <-> manifold projections ---
        self.attach('latent_down', nn.Linear(cfg.latent_flat_dim, cfg.manifold_dim))
        self.attach('latent_up', nn.Linear(cfg.manifold_dim, cfg.latent_flat_dim))

        # Pre-sigmoid blend between spring force and tower force.
        self.velocity_mix = nn.Parameter(torch.tensor(0.5))

    def patchify(self, z: Tensor) -> Tensor:
        """[B, C, H, W] -> [B, (H/p)*(W/p), tower_dim] with p = patch_size (4), plus position embedding."""
        B, C, H, W = z.shape
        p = self.objects['patch_size']

        z = z.unfold(2, p, p).unfold(3, p, p)
        z = z.permute(0, 2, 3, 1, 4, 5).contiguous()
        z = z.view(B, -1, C * p * p)

        return self['patch_proj'](z) + self.patch_pos_embed

    def get_tower_outputs(self, z: Tensor, text_seq: Tensor) -> List[Tensor]:
        """
        Run all four tower collectives.

        Returns the per-tower opinions (each [B, tower_dim]) ordered as vision
        geo, vision conv, text geo, text conv. The count is
        ``2 * (len(geo configs) + len(conv configs))``, i.e. 32 for the full config.
        """
        patches = self.patchify(z)
        text_tokens = self['text_bottleneck_seq'](text_seq)

        # Collectives are called in a fixed order (dropout RNG consumption).
        vision_geo = self['vision_geo'](patches)
        _, vision_conv_ops = self['vision_conv'](patches)
        text_geo = self['text_geo'](text_tokens)
        _, text_conv_ops = self['text_conv'](text_tokens)

        return (
            [op.opinion for op in vision_geo.opinions.values()] +
            list(vision_conv_ops.values()) +
            [op.opinion for op in text_geo.opinions.values()] +
            list(text_conv_ops.values())
        )

    def _text_reference(self, text_pooled: Tensor) -> Tensor:
        """Time-independent part of the reference point: pooled T5 -> bottleneck -> manifold."""
        return self['text_to_ref'](self['text_bottleneck_pool'](text_pooled))

    def _predict_velocity(
        self,
        z_t: Tensor,
        text_seq: Tensor,
        text_ref: Tensor,
        t: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Predict the flattened latent velocity at ``(z_t, t)``.

        Args:
            z_t: latents at time ``t``, [B, C, H, W].
            text_seq: raw T5 sequence embeddings for the tower collectives.
            text_ref: output of ``_text_reference`` for the same batch.
            t: per-sample times, [B].

        Returns:
            ``(v_pred, tau)``: ``v_pred`` is [B, C*H*W]; ``tau`` is the sigmoid
            blend weight given to the tower force.
        """
        z_proj = self['latent_down'](z_t.flatten(1))
        x_ref = text_ref + self['time_to_ref'](self['time_embed'](t))

        tower_outputs = self.get_tower_outputs(z_t, text_seq)
        tower_force, _ = self['oscillator'].force_generator(z_proj, tower_outputs, state_fingerprint=None)
        spring_force = x_ref - z_proj

        tau = torch.sigmoid(self.velocity_mix)
        v_pred_proj = (1 - tau) * spring_force + tau * tower_force

        return self['latent_up'](v_pred_proj), tau

    def forward(
        self,
        z_0: Tensor,
        text_seq: Tensor,
        text_pooled: Tensor,
        labels: Tensor,
        t: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Training forward - single step velocity prediction.

        Samples ``t ~ U[0, 1)`` (unless given) and Gaussian noise, then
        returns ``{'loss': MSE(v_pred, eps - z_0), 'tau': blend weight (detached)}``.
        ``labels`` is accepted for interface compatibility but not used;
        conditioning comes from ``text_seq`` / ``text_pooled``.
        """
        B = z_0.shape[0]

        if t is None:
            t = torch.rand(B, device=z_0.device)

        # Linear interpolation between data (t=0) and noise (t=1).
        eps = torch.randn_like(z_0)
        t_exp = t.view(B, 1, 1, 1)
        z_t = (1 - t_exp) * z_0 + t_exp * eps

        v_target = (eps - z_0).flatten(1)

        text_ref = self._text_reference(text_pooled)
        v_pred, tau = self._predict_velocity(z_t, text_seq, text_ref, t)

        loss = F.mse_loss(v_pred, v_target)

        return {'loss': loss, 'tau': tau.detach()}

    @torch.no_grad()
    def sample(
        self,
        text_seq: Tensor,
        text_pooled: Tensor,
        vae: SD15VAE,
        num_steps: Optional[int] = None,
    ) -> Tensor:
        """Generate images from text conditioning.

        Integrates the predicted velocity with Euler steps from ``t = 1``
        (pure noise) toward ``t = 0``, then decodes with ``vae``.
        ``num_steps`` defaults to ``cfg.num_flow_steps``.
        """
        cfg = self.objects['cfg']
        B = text_seq.shape[0]
        device = text_seq.device
        num_steps = num_steps or cfg.num_flow_steps
        latent_shape = (B, cfg.latent_channels, cfg.latent_size, cfg.latent_size)

        # The text part of the reference point does not depend on t; compute once.
        text_ref = self._text_reference(text_pooled)

        z = torch.randn(*latent_shape, device=device)
        dt = 1.0 / num_steps

        for step in range(num_steps):
            t = torch.full((B,), 1.0 - step * dt, device=device)
            v_pred, _ = self._predict_velocity(z, text_seq, text_ref, t)
            # dz/dt = v, so stepping toward t=0 subtracts dt * v.
            z = (z.flatten(1) - dt * v_pred).view(latent_shape)

        return vae.decode(z)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Trainer:
    """Training driver for BeatrixFlowT5 on cached CIFAR-10.

    Setup:
      - builds or loads the latent/T5 caches and data loaders;
      - builds the model and compiles it when ``torch.compile`` exists;
      - creates AdamW with a cosine LR schedule stepped once per batch;
      - resumes from ``ckpt_latest.pt``, the newest ``ckpt_epoch*.pt``, or
        the Hugging Face repo ``self.hf_repo``.

    Per epoch: train, validate, save a sample grid and checkpoint, append
    metrics to ``metrics.jsonl``, and upload artifacts on a best-effort basis.

    Mixed precision (autocast + GradScaler) is enabled only when the resolved
    device is CUDA, so the same code path also runs on CPU.
    """

    def __init__(self, cfg: FlowConfig):
        self.cfg = cfg
        self.device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
        self.output_dir = Path(cfg.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # The scaler always exists. When disabled (CPU), scale/unscale_/update
        # are no-ops and step() just calls optimizer.step(), so the loop needs
        # no device branches. Previously self.scaler was undefined without
        # CUDA and training crashed.
        self.use_amp = self.device.type == 'cuda'
        if self.use_amp:
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        print("\n=== Building Cached Datasets ===")
        # Use the *resolved* device so cache building also works without CUDA.
        data_device = str(self.device)
        self.train_dataset = CachedCIFAR10T5(train=True, image_size=cfg.image_size, cache_dir=cfg.cache_dir, device=data_device)
        self.val_dataset = CachedCIFAR10T5(train=False, image_size=cfg.image_size, cache_dir=cfg.cache_dir, device=data_device)

        # The model's text bottleneck input width must match the cached T5 width.
        cfg.text_raw_dim = self.train_dataset.text_dim
        print(f"T5 raw dimension: {cfg.text_raw_dim} → bottleneck: {cfg.bottleneck_dim}")

        self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=0, pin_memory=self.use_amp, drop_last=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0, pin_memory=self.use_amp)

        # Per-class prompt embeddings kept on-device for sampling.
        self.text_sequence = self.train_dataset.text_sequence.to(self.device)
        self.text_pooled = self.train_dataset.text_pooled.to(self.device)

        print("\n=== Building Model (Vision + Text Towers) ===")
        self.model = BeatrixFlowT5(cfg).to(self.device)

        if hasattr(torch, 'compile'):
            print("Compiling with WideRouter.prepare_and_compile()...")
            self.model = self.model.prepare_and_compile(
                mode="reduce-overhead",
                fullgraph=False,
            )

        num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Trainable parameters: {num_params:,}")

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=cfg.num_epochs * len(self.train_loader))

        self.start_epoch = 0
        self.hf_repo = "AbstractPhil/beatrix-diffusion-proto"
        self._load_latest_checkpoint()

        self._vae = None  # not used by this class; kept for backward compatibility

        self._setup_hf_repo()

    def _autocast(self):
        """Autocast context for the training device (disabled unless running on CUDA)."""
        return torch.amp.autocast(self.device.type, enabled=self.use_amp)

    def _setup_hf_repo(self):
        """Create the HF repo if needed and upload the model config (best effort).

        On any failure a warning is printed and ``self.hf_api`` is set to None,
        which disables later uploads.
        """
        try:
            self.hf_api = HfApi()
            create_repo(self.hf_repo, exist_ok=True, repo_type="model")
            print(f"HF repo: {self.hf_repo}")

            config_dict = {
                'image_size': self.cfg.image_size,
                'num_classes': self.cfg.num_classes,
                'latent_channels': self.cfg.latent_channels,
                'latent_size': self.cfg.latent_size,
                'text_raw_dim': self.cfg.text_raw_dim,
                'bottleneck_dim': self.cfg.bottleneck_dim,
                'tower_dim': self.cfg.tower_dim,
                'tower_depth': self.cfg.tower_depth,
                'num_heads': self.cfg.num_heads,
                'geometric_types': self.cfg.geometric_types,
                'conv_types': self.cfg.conv_types,
                'conv_spatial_size': self.cfg.conv_spatial_size,
                'manifold_dim': self.cfg.manifold_dim,
                'fingerprint_dim': self.cfg.fingerprint_dim,
                'num_flow_steps': self.cfg.num_flow_steps,
            }
            config_path = self.output_dir / "config.json"
            with open(config_path, 'w') as f:
                json.dump(config_dict, f, indent=2)

            upload_file(
                path_or_fileobj=str(config_path),
                path_in_repo="config.json",
                repo_id=self.hf_repo,
            )
        except Exception as e:
            print(f"HF setup warning: {e}")
            self.hf_api = None

    def _upload_to_hf(self, epoch: int, sample_path: Path, metrics: dict = None):
        """Record metrics locally, then best-effort upload checkpoint, samples and metrics.

        Metrics are appended to ``output_dir/metrics.jsonl`` even when HF is
        disabled. Previously they were lost in that case. Upload errors are
        printed and swallowed so training continues.
        """
        metrics_path = self.output_dir / "metrics.jsonl"
        if metrics:
            with open(metrics_path, 'a') as f:
                f.write(json.dumps({'epoch': epoch, **metrics}) + '\n')

        if self.hf_api is None:
            return

        try:
            ckpt_path = self.output_dir / "ckpt_latest.pt"
            if ckpt_path.exists():
                upload_file(
                    path_or_fileobj=str(ckpt_path),
                    path_in_repo="ckpt_latest.pt",
                    repo_id=self.hf_repo,
                )

            if sample_path.exists():
                upload_file(
                    path_or_fileobj=str(sample_path),
                    path_in_repo=f"samples/epoch_{epoch:03d}.png",
                    repo_id=self.hf_repo,
                )
                upload_file(
                    path_or_fileobj=str(sample_path),
                    path_in_repo="samples/latest.png",
                    repo_id=self.hf_repo,
                )

            if metrics:
                upload_file(
                    path_or_fileobj=str(metrics_path),
                    path_in_repo="metrics.jsonl",
                    repo_id=self.hf_repo,
                )

            print(" ✓ Uploaded to HF")
        except Exception as e:
            print(f" ✗ HF upload failed: {e}")

    def _load_latest_checkpoint(self):
        """Load the most recent checkpoint if available (local first, then HF).

        Lookup order: ``ckpt_latest.pt``, then the lexicographically last
        ``ckpt_epoch*.pt``, then ``ckpt_latest.pt`` downloaded from ``self.hf_repo``.
        """
        latest_path = self.output_dir / "ckpt_latest.pt"

        if latest_path.exists():
            print("Resuming from local ckpt_latest.pt...")
            ckpt = torch.load(latest_path, weights_only=False)
        else:
            ckpts = sorted(self.output_dir.glob("ckpt_epoch*.pt"))
            if ckpts:
                latest_path = ckpts[-1]
                print(f"Resuming from {latest_path.name}...")
                ckpt = torch.load(latest_path, weights_only=False)
            else:
                try:
                    from huggingface_hub import hf_hub_download
                    print("Checking HF for checkpoint...")
                    hf_path = hf_hub_download(
                        repo_id=self.hf_repo,
                        filename="ckpt_latest.pt",
                        local_dir=str(self.output_dir),
                    )
                    print("Downloaded checkpoint from HF")
                    # NOTE(review): unpickling a downloaded file; only safe for a trusted repo.
                    ckpt = torch.load(hf_path, weights_only=False)
                except Exception as e:
                    print(f"No checkpoint found (local or HF): {e}")
                    return

        self.model.load_state_dict(ckpt['model'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        self.scheduler.load_state_dict(ckpt['scheduler'])
        # Older checkpoints have no scaler entry; a disabled scaler saves {}.
        if ckpt.get('scaler'):
            self.scaler.load_state_dict(ckpt['scaler'])
        self.start_epoch = ckpt['epoch']
        print(f"  Resumed at epoch {self.start_epoch}")

    def _load_vae(self):
        """Load VAE for sampling (temporary)."""
        print("Loading VAE for sampling...")
        return SD15VAE(freeze=True).to(self.device)

    def _unload_vae(self, vae):
        """Drop this reference to ``vae`` and release cached CUDA memory.

        The caller must also drop its own reference for the memory to be freed.
        """
        del vae
        torch.cuda.empty_cache()

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Run one epoch; returns mean ``loss`` and ``tau`` (NaN if the loader is empty)."""
        self.model.train()
        total_loss, total_tau, n = 0.0, 0.0, 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.cfg.num_epochs}", leave=False)
        for latents, text_seq, text_pooled, labels in pbar:
            latents = latents.to(self.device)
            text_seq = text_seq.to(self.device)
            text_pooled = text_pooled.to(self.device)
            labels = labels.to(self.device)

            with self._autocast():
                out = self.model(latents, text_seq, text_pooled, labels)
                loss = out['loss']

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on real gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()

            loss_val = loss.item()
            tau_val = out['tau'].item()
            total_loss += loss_val
            total_tau += tau_val
            n += 1

            pbar.set_postfix(loss=f"{loss_val:.4f}", tau=f"{tau_val:.2f}")

        # drop_last=True can yield zero batches for tiny datasets.
        if n == 0:
            return {'loss': float('nan'), 'tau': float('nan')}
        return {'loss': total_loss / n, 'tau': total_tau / n}

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """Mean flow-matching loss over the validation set (random t/noise per batch)."""
        self.model.eval()
        total_loss, n = 0.0, 0

        for latents, text_seq, text_pooled, labels in self.val_loader:
            latents = latents.to(self.device)
            text_seq = text_seq.to(self.device)
            text_pooled = text_pooled.to(self.device)
            labels = labels.to(self.device)

            with self._autocast():
                out = self.model(latents, text_seq, text_pooled, labels)
            total_loss += out['loss'].item()
            n += 1

        if n == 0:
            return {'val_loss': float('nan')}
        return {'val_loss': total_loss / n}

    @torch.no_grad()
    def sample_images(self, n_per_class: int = 10) -> Tensor:
        """Generate ``n_per_class`` images per class (in batches of 10), mapped to [0, 1].

        Returns a [num_classes * n_per_class, C, H, W] tensor with images
        grouped by class in prompt order.
        """
        self.model.eval()
        torch.cuda.empty_cache()

        vae = self._load_vae()

        all_samples = []
        batch_size = 10
        num_classes = self.text_sequence.shape[0]

        try:
            for class_idx in range(num_classes):
                for batch_start in range(0, n_per_class, batch_size):
                    batch_n = min(batch_size, n_per_class - batch_start)

                    text_seq = self.text_sequence[class_idx:class_idx + 1].expand(batch_n, -1, -1)
                    text_pooled = self.text_pooled[class_idx:class_idx + 1].expand(batch_n, -1)

                    with self._autocast():
                        samples = self.model.sample(text_seq, text_pooled, vae)

                    all_samples.append(samples.cpu())
        finally:
            # Drop *this* frame's reference too; otherwise the VAE stays alive
            # and empty_cache() cannot return its memory.
            del vae
            torch.cuda.empty_cache()

        samples = torch.cat(all_samples, dim=0).to(self.device)
        return ((samples + 1) / 2).clamp(0, 1)

    def save_checkpoint(self, epoch: int, milestone: bool = False):
        """Write ``ckpt_latest.pt`` and, for milestones, ``ckpt_epoch{epoch:03d}.pt``."""
        ckpt = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'scaler': self.scaler.state_dict(),
        }

        torch.save(ckpt, self.output_dir / "ckpt_latest.pt")

        if milestone:
            torch.save(ckpt, self.output_dir / f"ckpt_epoch{epoch:03d}.pt")

    def train(self):
        """Run the remaining epochs, then write final samples and a final milestone checkpoint."""
        num_geo = len(self.cfg.geometric_types) * 2
        num_conv = len(self.cfg.conv_types) * 2
        total_towers = (num_geo + num_conv) * 2  # x2: vision and text collectives

        print(f"\n{'='*60}")
        print("BEATRIX FLOW - Dual Geometric + Conv Towers (Bottlenecked)")
        print(f"{'='*60}")
        print(f"Device: {self.device}")
        print(f"Geometric towers: {self.cfg.geometric_types} (pos/neg)")
        print(f"Conv towers: {self.cfg.conv_types} (pos/neg)")
        print(f"Tower dim: {self.cfg.tower_dim}")
        print(f"T5 raw → bottleneck: {self.cfg.text_raw_dim} → {self.cfg.bottleneck_dim}")
        print(f"Latent → manifold: {self.cfg.latent_flat_dim} → {self.cfg.manifold_dim}")
        print(f"Total towers: {total_towers}")
        print(f"Batch size: {self.cfg.batch_size}")
        print(f"Epochs: {self.start_epoch}/{self.cfg.num_epochs}")
        print(f"{'='*60}\n")

        for epoch in range(self.start_epoch, self.cfg.num_epochs):
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.validate()

            lr = self.scheduler.get_last_lr()[0]
            print(
                f"Epoch {epoch+1:3d} │ loss={train_metrics['loss']:.4f} │ "
                f"val={val_metrics['val_loss']:.4f} │ τ={train_metrics['tau']:.2f} │ lr={lr:.2e}"
            )

            samples = self.sample_images(10)
            grid = make_grid(samples, nrow=10, padding=2)
            sample_path = self.output_dir / f"samples_epoch{epoch+1:03d}.png"
            save_image(grid, sample_path)
            print(" ✓ Saved samples")

            self.save_checkpoint(epoch + 1, milestone=((epoch + 1) % 10 == 0))

            metrics = {
                'loss': train_metrics['loss'],
                'val_loss': val_metrics['val_loss'],
                'tau': train_metrics['tau'],
                'lr': lr,
            }
            self._upload_to_hf(epoch + 1, sample_path, metrics)

        samples = self.sample_images(10)
        grid = make_grid(samples, nrow=10, padding=2)
        final_path = self.output_dir / "samples_final.png"
        save_image(grid, final_path)
        self.save_checkpoint(self.cfg.num_epochs, milestone=True)
        self._upload_to_hf(self.cfg.num_epochs, final_path)
        print("\nTraining complete!")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main():
    """Train the reduced configuration.

    Uses two geometric and two conv tower types, each as pos/neg pairs, on
    both the vision and text sides (16 towers), with a 512-d manifold.
    """
    settings = dict(
        image_size=256,
        tower_dim=256,
        tower_depth=2,
        num_heads=8,
        geometric_types=('cantor', 'beatrix'),
        conv_types=('wide_resnet', 'squeeze_excite'),
        conv_spatial_size=8,
        bottleneck_dim=256,
        manifold_dim=512,
        batch_size=64,
        num_epochs=100,
        cache_dir="./cache",
        output_dir="./beatrix_cifar_t5",
    )
    Trainer(FlowConfig(**settings)).train()
| |
|
| |
|
def main_full():
    """Full 32-tower configuration.

    Uses four geometric and four conv tower types, each as pos/neg pairs, on
    both the vision and text sides, with a 1024-d manifold.
    """
    settings = dict(
        image_size=256,
        tower_dim=256,
        tower_depth=2,
        num_heads=8,
        geometric_types=('cantor', 'beatrix', 'helix', 'simplex'),
        conv_types=('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite'),
        conv_spatial_size=8,
        bottleneck_dim=256,
        manifold_dim=1024,
        batch_size=64,
        num_epochs=100,
        cache_dir="./cache",
        output_dir="./beatrix_cifar_t5",
    )
    Trainer(FlowConfig(**settings)).train()
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: the reduced configuration (call main_full() for all 32 towers).
    main()