"""
K-Simplex Language Model - Inference Script

Loads a trained k-simplex LLM checkpoint (from a local file or the Hugging Face
Hub) and generates text by autoregressive sampling from a model whose features
are gated by simplex-geometry validity checks. It can also report per-k
geometric statistics of the simplex encodings of a prompt.

Usage:
    python inference.py --checkpoint checkpoint_epoch_008.pt --prompt "ROMEO: "
    python inference.py --repo AbstractPhil/ksimplex-llm-prototype --prompt "To be or not"
    python inference.py --analyze --prompt "To be or not"
"""
| |
|
| | import argparse |
| | import json |
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import tiktoken |
| | from pathlib import Path |
| | from huggingface_hub import hf_hub_download |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def factorial(n: int) -> int:
    """Return ``n!`` for a non-negative integer ``n``.

    Thin alias over :func:`math.factorial`; invalid input (e.g. a negative
    ``n``) raises whatever :func:`math.factorial` raises.
    """
    result = math.factorial(n)
    return result
| |
|
| |
|
def cayley_menger_volume_squared(vertices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Squared k-volume of a simplex from its vertices, via the Cayley-Menger determinant.

    For ``nv = k + 1`` vertices, build the bordered matrix

        CM = [[0, 1, ..., 1],
              [1, D_00, ..., D_0k],
              ...
              [1, D_k0, ..., D_kk]]

    where ``D_ij`` is the squared distance between vertices i and j. Then

        V^2 = (-1)^(k+1) * det(CM) / (2^k * (k!)^2).

    Args:
        vertices: [*, nv, edim] vertex coordinates; leading dims are batch dims.

    Returns:
        d2: [*, nv*(nv-1)/2] squared edge lengths, strict upper triangle in
            row-major order (edge (0,1), (0,2), ..., (nv-2, nv-1)).
        vol2: [*] squared volume. It is 0 for a degenerate simplex and can come
            out slightly negative from floating-point error in the determinant.
    """
    num_verts = vertices.shape[-2]
    k = num_verts - 1

    # Pairwise squared Euclidean distances, [*, nv, nv].
    delta = vertices.unsqueeze(-2) - vertices.unsqueeze(-3)
    sq_dist = delta.pow(2).sum(-1)

    # Strict upper triangle -> one entry per edge.
    rows, cols = torch.triu_indices(num_verts, num_verts, offset=1, device=vertices.device)
    edge_d2 = sq_dist[..., rows, cols]

    # Bordered Cayley-Menger matrix: zero corner, ones along the border,
    # squared distances in the interior.
    cm = vertices.new_zeros(*vertices.shape[:-2], num_verts + 1, num_verts + 1)
    cm[..., 0, 1:] = 1.0
    cm[..., 1:, 0] = 1.0
    cm[..., 1:, 1:] = sq_dist

    sign = (-1) ** (k + 1)
    denom = (2 ** k) * (math.factorial(k) ** 2)
    vol2 = sign * torch.linalg.det(cm) / denom

    return edge_d2, vol2
| |
|
| |
|
| | |
| | |
| | |
| |
|
class SimplexTemplate(nn.Module):
    """Fixed (non-learned) reference vertices for a k-simplex.

    Vertex ``i`` of the ``k + 1`` vertices uses angle ``a = 2*pi*i/(k+1)``:
      * coord 0: ``scale * cos(a)``
      * coord 1: ``scale * sin(a)``              (if edim > 1)
      * coord 2: ``scale * 0.3 * cos(2a)``       (if edim > 2)
      * coord d >= 3: ``scale * 0.1 * sin((d+1) a)``

    Despite the original description, this is in general *not* a regular
    simplex (points on a circle with small sinusoidal offsets do not have
    equal edge lengths once k >= 3). It serves as the base shape that
    KSimplexChannel perturbs.

    The vertices are stored as a (persistent) buffer named ``template`` of
    shape [k + 1, edim], so they are saved/loaded with the state_dict.
    """

    def __init__(self, k: int, edim: int, scale: float = 1.0):
        super().__init__()
        self.k = k
        self.nv = k + 1
        self.edim = edim

        def coordinate(theta: float, d: int) -> float:
            # Value of coordinate ``d`` for a vertex at angle ``theta``.
            if d == 0:
                return scale * math.cos(theta)
            if d == 1:
                return scale * math.sin(theta)
            if d == 2:
                return scale * 0.3 * math.cos(theta * 2)
            return scale * 0.1 * math.sin(theta * (d + 1))

        points = torch.zeros(self.nv, edim)
        for idx in range(self.nv):
            theta = 2 * math.pi * idx / self.nv
            points[idx, 0] = coordinate(theta, 0)
            for d in range(1, edim):
                points[idx, d] = coordinate(theta, d)

        self.register_buffer('template', points)

    def forward(self) -> torch.Tensor:
        """Return the [k + 1, edim] template vertices."""
        return self.template
| |
|
| |
|
class KSimplexChannel(nn.Module):
    """One k-simplex channel: hidden state -> deformed simplex -> gated features + geometry.

    From a hidden vector the channel predicts
      * per-vertex coordinate offsets, scaled by ``base_deform`` and added to a
        fixed SimplexTemplate, giving a deformed simplex with ``k + 1`` vertices;
      * per-vertex feature vectors of size ``feat_dim``.

    The deformed simplex's Cayley-Menger geometry (squared edge lengths plus
    squared volume, ``geo_dim`` values) drives a learned sigmoid gate over the
    mean vertex feature. A second, very steep sigmoid on the squared volume
    multiplies the result, suppressing features from simplices whose squared
    volume is not positive.
    """

    def __init__(self, k: int, edim: int, hidden: int, feat_dim: int, base_deform: float = 0.05):
        super().__init__()
        self.k = k
        self.nv = k + 1
        self.edim = edim
        self.feat_dim = feat_dim
        self.base_deform = base_deform

        self.template = SimplexTemplate(k, edim)
        self._to_coords = nn.Linear(hidden, self.nv * edim)
        self._to_feats = nn.Linear(hidden, self.nv * feat_dim)

        # Geometry vector = one squared distance per edge, plus the squared volume.
        self.geo_dim = math.comb(self.nv, 2) + 1
        self._geo_gate = nn.Sequential(nn.Linear(self.geo_dim, feat_dim), nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x`` through this channel.

        Args:
            x: [*, hidden]

        Returns:
            out: [*, feat_dim + geo_dim]; gated mean vertex feature followed by
                the raw geometry vector (edge squared lengths, squared volume).
            vol2: [*] squared volume of the deformed simplex.
            mean_d2: [*] mean squared edge length.
        """
        offsets = self._to_coords(x).unflatten(-1, (self.nv, self.edim))
        verts = self.template() + self.base_deform * offsets

        per_vertex = self._to_feats(x).unflatten(-1, (self.nv, self.feat_dim))

        d2, vol2 = cayley_menger_volume_squared(verts)
        geo = torch.cat([d2, vol2[..., None]], dim=-1)

        # sigmoid(1e6 * vol2): exactly 0.5 at vol2 == 0, ~1 once vol2 is well
        # above ~1e-6, ~0 for negative vol2 -> a soft "valid simplex" mask.
        validity = torch.sigmoid(vol2 * 1e6)[..., None]
        gated = per_vertex.mean(dim=-2) * self._geo_gate(geo) * validity

        return torch.cat([gated, geo], dim=-1), vol2, d2.mean(dim=-1)
| |
|
| |
|
class TokenToKChannels(nn.Module):
    """Project token embeddings into ``depth`` k-simplex channels (k = 1 .. depth).

    All channels read the same shared linear projection of the embedding.
    Channel output widths differ (``feat_dim + C(k+1, 2) + 1``), so every
    output narrower than the widest one is lifted by its own ``nn.Linear``
    (the widest passes through ``nn.Identity``) before stacking along a new
    K axis.
    """

    def __init__(self, embed_dim: int, hidden: int, depth: int, edim: int, feat_dim: int):
        super().__init__()
        self.depth = depth

        self._proj = nn.Linear(embed_dim, hidden)
        self._channels = nn.ModuleList(
            KSimplexChannel(k=level, edim=edim, hidden=hidden, feat_dim=feat_dim)
            for level in range(1, depth + 1)
        )

        self.out_dims = [channel.feat_dim + channel.geo_dim for channel in self._channels]
        self.max_dim = max(self.out_dims)

        pads = []
        for width in self.out_dims:
            pads.append(nn.Identity() if width == self.max_dim else nn.Linear(width, self.max_dim))
        self._pads = nn.ModuleList(pads)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
        """Run every channel on the shared projection of ``x``.

        Args:
            x: [B, T, embed_dim]

        Returns:
            out: [B, T, K, max_dim] stacked, width-aligned channel outputs.
            vol2_list: K tensors of shape [B, T], squared volume per channel.
            d2_list: K tensors of shape [B, T], mean squared edge length per channel.
        """
        shared = self._proj(x)

        projected, vol2_list, d2_list = [], [], []
        for channel, pad in zip(self._channels, self._pads):
            feats, vol2, mean_d2 = channel(shared)
            projected.append(pad(feats))
            vol2_list.append(vol2)
            d2_list.append(mean_d2)

        return torch.stack(projected, dim=-2), vol2_list, d2_list
| |
|
| |
|
class KChannelCrossAttention(nn.Module):
    """Self-attention across the K k-levels, independently at every (batch, position).

    Post-norm residual: ``LayerNorm(x + MHA(x, x, x))`` over the K axis.
    ``dim`` must be divisible by ``num_heads`` (an nn.MultiheadAttention
    requirement).
    """

    def __init__(self, dim: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix information between k-levels.

        Args:
            x: [B, T, K, D]

        Returns:
            [B, T, K, D]
        """
        batch, seq, levels, width = x.shape

        # Each (batch, position) pair becomes its own length-K "sequence".
        cells = x.view(batch * seq, levels, width)
        mixed = self.attn(cells, cells, cells)[0]

        return self.norm(cells + mixed).view(batch, seq, levels, width)
| |
|
| |
|
class CausalSequenceAttention(nn.Module):
    """Causal self-attention over sequence positions, run independently per k-channel.

    ``attn`` and ``norm`` are sized for the per-channel width ``dim`` (= D), so
    each of the K channels is processed as its own length-T sequence: position
    t of channel k attends only to positions <= t of that same channel.
    Post-norm residual: ``LayerNorm(x + MHA(x, x, x))``.

    NOTE(review): the previous forward flattened [B, T, K, D] to K*D features
    and fed them to modules built for D, which cannot run for K > 1. Per-channel
    attention is the interpretation consistent with these parameter shapes
    (so existing state_dicts still load) -- confirm against the training code.
    """

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)

        # Lower-triangular "may attend" mask, sliced to the live length in forward.
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
        self.register_buffer('_causal_mask', mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal attention along T for every (batch, channel).

        Args:
            x: [B, T, K, D]

        Returns:
            [B, T, K, D] (contiguous, so callers may ``.view`` it).

        Raises:
            ValueError: if T exceeds the ``max_seq_len`` the mask was built for.
        """
        B, T, K, D = x.shape
        max_len = self._causal_mask.shape[0]
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds max_seq_len={max_len}")

        # [B, T, K, D] -> [B*K, T, D]: one causal sequence per (batch, channel).
        seqs = x.permute(0, 2, 1, 3).reshape(B * K, T, D)

        # Boolean attn_mask for nn.MultiheadAttention: True marks positions
        # that may NOT be attended (i.e. the future).
        blocked = ~self._causal_mask[:T, :T]
        attn_out, _ = self.attn(seqs, seqs, seqs, attn_mask=blocked, need_weights=False)

        out = self.norm(seqs + attn_out)

        # Back to [B, T, K, D]; contiguous because downstream code uses .view().
        return out.reshape(B, K, T, D).permute(0, 2, 1, 3).contiguous()
| |
|
| |
|
class GeoBlock(nn.Module):
    """Residual block over [B, T, K, D]: k-level mixing, causal sequence mixing, then an MLP.

    Sub-layers, applied in order:
      1. ``k_attn``: attention across the K k-levels (fixed 4 heads, so ``dim``
         must be divisible by 4);
      2. ``seq_attn``: causal attention across positions (``num_heads`` heads);
      3. an MLP over the K channels flattened to ``dim * depth`` features with a
         4x expansion, added residually and followed by LayerNorm.

    The flattening in step 3 means the incoming K must equal ``depth``.
    """

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, depth: int, dropout: float = 0.1):
        super().__init__()
        self.k_attn = KChannelCrossAttention(dim, num_heads=4, dropout=dropout)
        self.seq_attn = CausalSequenceAttention(dim, num_heads, max_seq_len, dropout)

        flat_dim = dim * depth
        expanded = flat_dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(flat_dim, expanded),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded, flat_dim),
            nn.Dropout(dropout),
        )
        self.mlp_norm = nn.LayerNorm(flat_dim)
        self.depth = depth

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            x: [B, T, K, D]

        Returns:
            [B, T, K, D]
        """
        x = self.seq_attn(self.k_attn(x))

        B, T, K, D = x.shape
        flat = x.view(B, T, K * D)
        flat = self.mlp_norm(flat + self.mlp(flat))
        return flat.view(B, T, K, D)
| |
|
| |
|
class KSimplexLM(nn.Module):
    """K-Simplex language model.

    Pipeline:
      token + learned positional embeddings (dropout)
      -> TokenToKChannels: one k-simplex channel per k = 1..depth, stacked to
         [B, T, K, k_dim]
      -> ``num_blocks`` GeoBlocks
      -> LayerNorm over the flattened K * k_dim features
      -> bias-free linear LM head.

    ``k_dim`` is ``to_k_channels.max_dim`` (= feat_dim + C(depth+1, 2) + 1).
    Every GeoBlock builds two nn.MultiheadAttention layers with embedding size
    ``k_dim`` (4 heads and ``num_heads`` heads), so ``k_dim`` must be divisible
    by both; the constructor checks this up front with a readable error.
    """

    def __init__(
        self,
        vocab_size: int = 50257,
        max_seq_len: int = 256,
        embed_dim: int = 384,
        depth: int = 4,
        edim: int = 16,
        feat_dim: int = 96,
        hidden: int = 384,
        num_heads: int = 8,
        num_blocks: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.depth = depth

        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self.embed_drop = nn.Dropout(dropout)

        self.to_k_channels = TokenToKChannels(embed_dim, hidden, depth, edim, feat_dim)

        k_dim = self.to_k_channels.max_dim
        # Fail with an actionable message instead of nn.MultiheadAttention's
        # bare "embed_dim must be divisible by num_heads" assertion.
        for heads in (4, num_heads):
            if k_dim % heads:
                raise ValueError(
                    f"k-channel width {k_dim} (feat_dim={feat_dim}, depth={depth}) "
                    f"must be divisible by {heads} attention heads"
                )

        self.blocks = nn.ModuleList([
            GeoBlock(k_dim, num_heads, max_seq_len, depth, dropout)
            for _ in range(num_blocks)
        ])

        self.ln_f = nn.LayerNorm(k_dim * depth)
        self.lm_head = nn.Linear(k_dim * depth, vocab_size, bias=False)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear/Embedding weight with N(0, 0.02); zero Linear biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Compute next-token logits.

        Args:
            x: [B, T] token indices, T <= max_seq_len.

        Returns:
            logits: [B, T, vocab_size]
            geo_info: {'vol2': K tensors [B, T] of squared simplex volumes,
                       'd2': K tensors [B, T] of mean squared edge lengths}

        Raises:
            ValueError: if T exceeds ``max_seq_len`` (no positional embedding).
        """
        B, T = x.shape
        if T > self.max_seq_len:
            raise ValueError(f"sequence length {T} exceeds max_seq_len={self.max_seq_len}")

        pos = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.embed_drop(self.embed(x) + self.pos_embed(pos))

        h, vol2_list, d2_list = self.to_k_channels(h)

        for block in self.blocks:
            h = block(h)

        # [B, T, K, D] -> [B, T, K*D]; flatten copes with non-contiguous input.
        logits = self.lm_head(self.ln_f(h.flatten(-2)))

        geo_info = {
            'vol2': vol2_list,
            'd2': d2_list,
        }
        return logits, geo_info
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_model(
    checkpoint_path: str = None,
    repo_id: str = None,
    device: str = None,
) -> tuple[KSimplexLM, tiktoken.Encoding]:
    """Build a KSimplexLM and load trained weights from a local checkpoint or the HF Hub.

    If ``repo_id`` is given it takes precedence: ``checkpoint_latest.pt`` and
    ``config.json`` are downloaded from that repo. Otherwise the local
    ``checkpoint_path`` is used and the model config is read from
    ``checkpoint['config']['model']`` when present. Missing config keys fall
    back to the KSimplexLM defaults; dropout is always disabled for inference.

    Args:
        checkpoint_path: Local path to a checkpoint file.
        repo_id: Hugging Face repo ID (e.g. "AbstractPhil/ksimplex-llm-prototype").
        device: Target device; defaults to "cuda" when available, else "cpu".

    Returns:
        model: KSimplexLM in eval mode on ``device``, weights loaded.
        tokenizer: tiktoken GPT-2 encoding.

    Raises:
        ValueError: if neither ``checkpoint_path`` nor ``repo_id`` is given.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources (consider weights_only=True where supported).
    if repo_id:
        checkpoint_path = hf_hub_download(repo_id, "checkpoint_latest.pt")
        config_path = hf_hub_download(repo_id, "config.json")
        with open(config_path, encoding='utf-8') as f:
            config = json.load(f)
        # Accept both a flat model config and one nested under "model"
        # (the layout the local-checkpoint branch expects).
        if isinstance(config.get('model'), dict):
            config = config['model']
        checkpoint = torch.load(checkpoint_path, map_location=device)
    elif checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        config = (checkpoint.get('config') or {}).get('model') or {}
    else:
        raise ValueError("Must provide checkpoint_path or repo_id")

    defaults = {
        'vocab_size': 50257,
        'max_seq_len': 256,
        'embed_dim': 384,
        'depth': 4,
        'edim': 16,
        'feat_dim': 96,
        'hidden': 384,
        'num_heads': 8,
        'num_blocks': 8,
    }
    model = KSimplexLM(
        **{key: config.get(key, value) for key, value in defaults.items()},
        dropout=0.0,
    )

    # Weights are loaded for BOTH sources; a bare state_dict is accepted too.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()

    tokenizer = tiktoken.get_encoding("gpt2")

    return model, tokenizer
| |
|
| |
|
def _filter_logits(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Set logits outside the top-k set / top-p nucleus to -inf (in place) and return them.

    Args:
        logits: [B, vocab] logits; modified in place.
        top_k: keep only the ``top_k`` largest logits (disabled when <= 0).
        top_p: keep the smallest prefix of sorted tokens whose cumulative
            probability reaches ``top_p`` (disabled when >= 1.0).
    """
    if top_k > 0:
        kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
        logits[logits < kth_best] = float('-inf')

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Shift right so the token that first crosses top_p is kept, and
        # always keep the single most likely token.
        remove_sorted = cumulative_probs > top_p
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = False

        remove = remove_sorted.scatter(1, sorted_indices, remove_sorted)
        logits[remove] = float('-inf')

    return logits


@torch.no_grad()
def generate(
    model: KSimplexLM,
    tokenizer: tiktoken.Encoding,
    prompt: str,
    max_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    device: str = None,
) -> str:
    """Autoregressively extend ``prompt`` and return prompt + continuation.

    Only the last ``model.max_seq_len`` tokens are fed to the model at each
    step, but the full token history is kept, so the returned text always
    starts with the prompt. Generation stops after ``max_tokens`` new tokens
    or right after the tokenizer's end-of-text token is sampled (that token is
    included in the output).

    Args:
        model: KSimplexLM model.
        tokenizer: tiktoken encoding.
        prompt: Input text prompt. An empty prompt is seeded with the
            end-of-text token, which is dropped from the returned text.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_k: Top-k sampling (disabled when <= 0).
        top_p: Nucleus sampling threshold (disabled when >= 1.0).
        device: Device for the token tensor; defaults to the model's device.

    Returns:
        Generated text including the prompt.
    """
    if device is None:
        device = next(model.parameters()).device

    prompt_ids = tokenizer.encode(prompt)
    # The model needs at least one position to predict from.
    seeded = not prompt_ids
    if seeded:
        prompt_ids = [tokenizer.eot_token]

    tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    for _ in range(max_tokens):
        # Sliding context window; `tokens` itself keeps the whole history.
        context = tokens[:, -model.max_seq_len:]
        logits, _ = model(context)
        logits = logits[:, -1, :]

        if temperature <= 0:
            # Dividing by a non-positive temperature would yield inf/NaN probabilities.
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            logits = _filter_logits(logits / temperature, top_k, top_p)
            next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

        tokens = torch.cat([tokens, next_token], dim=1)

        if next_token.item() == tokenizer.eot_token:
            break

    output_ids = tokens[0].tolist()
    if seeded:
        output_ids = output_ids[1:]
    return tokenizer.decode(output_ids)
| |
|
| |
|
@torch.no_grad()
def analyze_geometry(
    model: KSimplexLM,
    tokenizer: tiktoken.Encoding,
    text: str,
    device: str = None,
) -> dict:
    """Summarise the per-k simplex geometry the model produces for ``text``.

    Only the last ``model.max_seq_len`` tokens of ``text`` are analysed,
    because the positional embedding has no entries beyond that length.

    Args:
        model: KSimplexLM model.
        tokenizer: tiktoken encoding.
        text: Input text; must encode to at least one token.
        device: Device for the token tensor; defaults to the model's device.

    Returns:
        ``{'k1': {...}, 'k2': {...}, ...}``, one entry per k-channel. Each entry holds:
          * vol2_mean / vol2_std / vol2_min / vol2_max: statistics of the squared
            simplex volume over all positions (std is the population std);
          * validity_rate: fraction of positions with squared volume > 0;
          * d2_mean: mean over positions of the mean squared edge length.

    Raises:
        ValueError: if ``text`` encodes to no tokens.
    """
    if device is None:
        device = next(model.parameters()).device

    token_ids = tokenizer.encode(text)
    if not token_ids:
        raise ValueError("text must encode to at least one token")
    token_ids = token_ids[-model.max_seq_len:]

    tokens = torch.tensor([token_ids], dtype=torch.long, device=device)

    _, geo_info = model(tokens)

    stats = {}
    for k, (vol2, d2) in enumerate(zip(geo_info['vol2'], geo_info['d2']), 1):
        # .float() first: numpy cannot represent bfloat16 tensors.
        vol2_np = vol2.float().cpu().numpy()
        d2_np = d2.float().cpu().numpy()

        stats[f'k{k}'] = {
            'vol2_mean': float(vol2_np.mean()),
            'vol2_std': float(vol2_np.std()),
            'vol2_min': float(vol2_np.min()),
            'vol2_max': float(vol2_np.max()),
            'validity_rate': float((vol2_np > 0).mean()),
            'd2_mean': float(d2_np.mean()),
        }

    return stats
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main():
    """Command-line entry point: parse flags, load the model, then generate or analyse.

    A local ``--checkpoint`` takes precedence over ``--repo``; ``--analyze``
    prints per-k geometry statistics for the prompt instead of generating.
    """
    parser = argparse.ArgumentParser(description='K-Simplex LLM Inference')
    cli_options = (
        ('--checkpoint', {'type': str, 'help': 'Path to checkpoint file'}),
        ('--repo', {'type': str, 'default': 'AbstractPhil/ksimplex-llm-prototype',
                    'help': 'HuggingFace repo ID'}),
        ('--prompt', {'type': str, 'default': 'ROMEO: ',
                      'help': 'Text prompt'}),
        ('--max_tokens', {'type': int, 'default': 100,
                          'help': 'Maximum tokens to generate'}),
        ('--temperature', {'type': float, 'default': 0.8,
                           'help': 'Sampling temperature'}),
        ('--top_k', {'type': int, 'default': 50,
                     'help': 'Top-k sampling'}),
        ('--top_p', {'type': float, 'default': 0.9,
                     'help': 'Nucleus sampling threshold'}),
        ('--analyze', {'action': 'store_true',
                       'help': 'Analyze geometric properties instead of generating'}),
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    print("Loading model...")
    repo_id = None if args.checkpoint else args.repo
    model, tokenizer = load_model(checkpoint_path=args.checkpoint, repo_id=repo_id)
    print(f"Model loaded on {next(model.parameters()).device}")

    if not args.analyze:
        print(f"\nGenerating from: {args.prompt}")
        text = generate(
            model, tokenizer, args.prompt,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        rule = "=" * 60
        print("\n" + rule)
        print(text)
        print(rule)
        return

    print(f"\nAnalyzing: {args.prompt}")
    for level_name, level_stats in analyze_geometry(model, tokenizer, args.prompt).items():
        print(f"\n{level_name}:")
        for metric, value in level_stats.items():
            print(f" {metric}: {value:.6f}")


if __name__ == '__main__':
    main()