| """ |
| Muon Optimizer for BitTransformerLM Extensions |
| ============================================== |
| |
| Implementation of the Muon optimizer with orthogonal momentum updates. |
| Based on the Muon ("MomentUm Orthogonalized by Newton-Schulz") optimizer research. |
| |
| Key features: |
| - Orthogonal momentum updates |
| - Better convergence properties than Adam/AdamW |
| - Memory efficient implementation |
| - Compatible with BitTransformerLM's training infrastructure |
| """ |
|
|
| import math |
| import torch |
| from torch.optim.optimizer import Optimizer |
| from typing import Any, Dict, List, Optional, Tuple, Union |
| import warnings |
|
|
|
|
class Muon(Optimizer):
    """
    Muon optimizer with orthogonal momentum updates.

    For every parameter a heavy-ball momentum buffer is accumulated
    (``buf = momentum * buf + grad``). Every ``update_period`` steps, buffers
    that are matrix-shaped (at least 2 dims, both trailing dims > 1) are
    orthogonalized and the result is written back into the buffer. The
    parameter is then moved along the (optionally Nesterov-corrected) buffer.

    Args:
        params: Iterable of parameters to optimize
        lr: Learning rate (default: 1e-3)
        momentum: Momentum factor (default: 0.95)
        nesterov: Enable Nesterov momentum (default: False)
        backend: Backend for orthogonalization ('newtonschulz' or 'svd')
        update_period: Period for updating orthogonalization (default: 1)
        rank_deficiency_threshold: If the ratio between the Frobenius norm of
            the orthogonalized buffer and that of the raw buffer falls below
            this value, a warning is emitted and the buffer is left unchanged
        eps: Small constant added to the norm used to scale the matrix before
            the Newton-Schulz iteration (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 0.0)

    Raises:
        ValueError: If ``lr``, ``momentum``, ``weight_decay``, ``backend`` or
            ``update_period`` is outside its valid range.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        momentum: float = 0.95,
        nesterov: bool = False,
        backend: str = "newtonschulz",
        update_period: int = 1,
        rank_deficiency_threshold: float = 1e-6,
        eps: float = 1e-8,
        weight_decay: float = 0.0,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= momentum <= 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if backend not in ("newtonschulz", "svd"):
            raise ValueError(f"Invalid backend: {backend}")
        # step() computes ``step % update_period``; reject values that would
        # raise ZeroDivisionError (or never trigger) deep inside training.
        if update_period < 1:
            raise ValueError(f"Invalid update_period: {update_period}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            backend=backend,
            update_period=update_period,
            rank_deficiency_threshold=rank_deficiency_threshold,
            eps=eps,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    def _orthogonalize_newtonschulz(
        self,
        matrix: torch.Tensor,
        num_iterations: int = 5,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        """Approximately orthogonalize ``matrix`` with Newton-Schulz iterations.

        Tensors with more than two dims are flattened to
        ``(-1, matrix.shape[-1])`` for the iteration and reshaped back.

        Args:
            matrix: Tensor to orthogonalize (at least 2-D).
            num_iterations: Number of cubic Newton-Schulz steps to run.
            eps: Added to the Frobenius norm to avoid division by zero.

        Returns:
            A new tensor with the same shape as ``matrix``.
        """
        original_shape = matrix.shape
        if matrix.dim() > 2:
            # reshape (not view) so non-contiguous inputs are accepted.
            matrix = matrix.reshape(-1, matrix.shape[-1])

        # The iteration X <- X (1.5 I - 0.5 X^T X) only converges when every
        # singular value lies in (0, sqrt(3)). The Frobenius norm bounds the
        # spectral norm from above, so dividing by it puts all singular
        # values in [0, 1] and prevents the iteration from diverging.
        X = matrix / (torch.linalg.vector_norm(matrix) + eps)

        # Multiply on the side that yields the smaller Gram matrix.
        tall = X.shape[0] >= X.shape[1]
        gram_size = X.shape[1] if tall else X.shape[0]
        identity = torch.eye(gram_size, device=X.device, dtype=X.dtype)

        for _ in range(num_iterations):
            if tall:
                X = X @ (1.5 * identity - 0.5 * (X.T @ X))
            else:
                X = (1.5 * identity - 0.5 * (X @ X.T)) @ X

        return X.reshape(original_shape)

    def _orthogonalize_svd(self, matrix: torch.Tensor) -> torch.Tensor:
        """Orthogonalize ``matrix`` as ``U @ Vt`` from its reduced SVD.

        Tensors with more than two dims are flattened to
        ``(-1, matrix.shape[-1])`` first. If the SVD raises a linear-algebra
        error, a warning is emitted and Newton-Schulz is used instead.

        Returns:
            A new tensor with the same shape as ``matrix``.
        """
        original_shape = matrix.shape
        if matrix.dim() > 2:
            matrix = matrix.reshape(-1, matrix.shape[-1])

        try:
            U, _, Vt = torch.linalg.svd(matrix, full_matrices=False)
        except torch.linalg.LinAlgError:
            warnings.warn("SVD failed, falling back to Newton-Schulz")
            # ``matrix`` is the flattened 2-D view here, so restore the
            # caller's shape on the fallback path as well.
            return self._orthogonalize_newtonschulz(matrix).reshape(original_shape)

        return (U @ Vt).reshape(original_shape)

    def _orthogonalize_buffer(
        self, momentum_buffer: torch.Tensor, group: Dict[str, Any]
    ) -> None:
        """Orthogonalize a matrix-shaped momentum buffer in place.

        Buffers that are not matrix-shaped, are empty, or are exactly zero
        are left untouched.
        """
        if momentum_buffer.numel() <= 1:
            return
        if momentum_buffer.dim() < 2 or min(momentum_buffer.shape[-2:]) <= 1:
            return

        # Flattened (Frobenius) norm: always a scalar, unlike matrix_norm,
        # which returns one value per leading batch index for >2-D tensors.
        buffer_norm = torch.linalg.vector_norm(momentum_buffer)
        if buffer_norm == 0:
            # Nothing to orthogonalize; also avoids a 0/0 ratio below and an
            # arbitrary orthogonal matrix from the SVD of a zero matrix.
            return

        if group["backend"] == "newtonschulz":
            orthogonal_momentum = self._orthogonalize_newtonschulz(
                momentum_buffer, eps=group["eps"]
            )
        else:
            orthogonal_momentum = self._orthogonalize_svd(momentum_buffer)

        rank_ratio = torch.linalg.vector_norm(orthogonal_momentum) / buffer_norm
        if rank_ratio < group["rank_deficiency_threshold"]:
            warnings.warn("Detected rank deficiency in momentum buffer")
        else:
            momentum_buffer.copy_(orthogonal_momentum)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            weight_decay = group["weight_decay"]
            update_period = group["update_period"]
            nesterov = group["nesterov"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                # Accumulate low-precision gradients in float32.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]
                if not state:
                    state["step"] = 0
                    state["momentum_buffer"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                momentum_buffer = state["momentum_buffer"]
                state["step"] += 1

                # L2-style weight decay folded into the gradient. ``add`` is
                # out of place, so ``p.grad`` itself is never modified.
                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)

                momentum_buffer.mul_(momentum).add_(grad)

                if state["step"] % update_period == 0:
                    self._orthogonalize_buffer(momentum_buffer, group)

                if nesterov:
                    update = grad.add(momentum_buffer, alpha=momentum)
                else:
                    update = momentum_buffer

                p.add_(update, alpha=-lr)

        return loss
|
|
|
|
def configure_muon_optimizer(
    model: torch.nn.Module,
    lr: float = 1e-3,
    momentum: float = 0.95,
    weight_decay: float = 0.01,
    total_steps: Optional[int] = None,
    warmup_ratio: float = 0.1,
    nesterov: bool = False,
    backend: str = "newtonschulz",
    **muon_kwargs
) -> Tuple[Muon, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """
    Build a Muon optimizer and, optionally, a OneCycle learning-rate schedule.

    Trainable parameters are split into two groups: tensors with two or more
    dims receive ``weight_decay``; all others (fewer than two dims) get a
    weight decay of 0.0.

    Args:
        model: PyTorch model to optimize
        lr: Peak learning rate (also used as ``max_lr`` of the schedule)
        momentum: Momentum factor for Muon
        weight_decay: Weight decay coefficient for the >=2-D parameter group
        total_steps: Total training steps for the OneCycle schedule; when
            ``None`` or not positive, no scheduler is created
        warmup_ratio: Fraction of steps for warmup (``pct_start``)
        nesterov: Enable Nesterov momentum
        backend: Orthogonalization backend
        **muon_kwargs: Additional arguments forwarded to Muon

    Returns:
        Tuple of (optimizer, scheduler); scheduler is ``None`` when no
        schedule was requested.
    """
    trainable = [
        param for _, param in model.named_parameters() if param.requires_grad
    ]
    decayed = [param for param in trainable if param.dim() >= 2]
    undecayed = [param for param in trainable if param.dim() < 2]

    optimizer = Muon(
        [
            {"params": decayed, "weight_decay": weight_decay},
            {"params": undecayed, "weight_decay": 0.0},
        ],
        lr=lr,
        momentum=momentum,
        nesterov=nesterov,
        backend=backend,
        **muon_kwargs
    )

    # Without a positive step budget there is nothing to schedule.
    if total_steps is None or not total_steps > 0:
        return optimizer, None

    one_cycle = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        total_steps=total_steps,
        pct_start=warmup_ratio,
        anneal_strategy='cos',
        cycle_momentum=False,
        div_factor=25.0,
        final_div_factor=1e4,
    )
    return optimizer, one_cycle
|
|
|
|
def create_muon_training_config(
    lr: float = 1e-3,
    momentum: float = 0.95,
    weight_decay: float = 0.01,
    backend: str = "newtonschulz",
    nesterov: bool = False,
    **kwargs
) -> Dict[str, Any]:
    """
    Assemble a training configuration dictionary selecting the Muon optimizer.

    Args:
        lr: Learning rate
        momentum: Momentum factor
        weight_decay: Weight decay coefficient
        backend: Orthogonalization backend
        nesterov: Enable Nesterov momentum
        **kwargs: Extra entries appended to the optimizer configuration

    Returns:
        Dictionary with the keys ``optimizer_type`` ("muon"),
        ``optimizer_config`` (the arguments above plus ``kwargs``) and
        ``scheduler_type`` ("onecycle").
    """
    optimizer_config: Dict[str, Any] = dict(
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
        backend=backend,
        nesterov=nesterov,
    )
    # Extra options go after the named settings, preserving key order.
    optimizer_config.update(kwargs)

    return {
        "optimizer_type": "muon",
        "optimizer_config": optimizer_config,
        "scheduler_type": "onecycle",
    }
|
|
|
|
| |
def integrate_with_bittransformerlm():
    """
    Documentation-only placeholder showing how to plug Muon into
    BitTransformerLM training. Calling it does nothing and returns ``None``.

    Usage:
        from BTLM_Extensions.muon_optimizer import configure_muon_optimizer

        # Replace the standard optimizer configuration
        optimizer, scheduler = configure_muon_optimizer(
            model, lr=1e-3, momentum=0.95, total_steps=1000
        )

        # Use in training loop
        train_loop(model, data, optimizer=optimizer, scheduler=scheduler)
    """
    return None
|
|
|
|
if __name__ == "__main__":
    # Smoke test: one forward/backward/step cycle on a tiny MLP.
    import torch.nn as nn

    layers = [nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1)]
    model = nn.Sequential(*layers)

    optimizer, scheduler = configure_muon_optimizer(model, lr=1e-3, total_steps=100)

    # Random regression batch: 32 samples, 10 features, 1 target each.
    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)

    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()

    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    print("Muon optimizer test completed successfully!")
    print(f"Loss: {loss.item():.4f}")