| """ |
| Lion Optimizer for BitTransformerLM Extensions |
| ============================================== |
| |
| Implementation of the Lion optimizer (EvoLved Sign Momentum). |
| Based on "Symbolic Discovery of Optimization Algorithms" research. |
| |
| Key features: |
| - Sign-based momentum updates |
| - Extremely memory efficient (only stores momentum) |
| - Often outperforms Adam/AdamW with larger learning rates |
| - Compatible with BitTransformerLM's training infrastructure |
| """ |


import torch
from torch.optim.optimizer import Optimizer
from typing import Any, Dict, Optional, Tuple


class Lion(Optimizer):
    """
    Lion optimizer implementation.

    Lion keeps one momentum buffer per parameter and steps in the direction
    of the *sign* of an interpolation between that momentum and the current
    gradient, making it very memory efficient while maintaining competitive
    performance:

        p <- p * (1 - lr * weight_decay)                 # decoupled decay
        p <- p - lr * sign(beta1 * m + (1 - beta1) * g)  # sign update
        m <- beta2 * m + (1 - beta2) * g                 # momentum update

    Args:
        params: Iterable of parameters to optimize
        lr: Learning rate (default: 1e-4; typically 3-10x smaller than Adam's)
        betas: Coefficients for the update and momentum interpolations
            (default: (0.9, 0.99))
        weight_decay: Decoupled weight decay coefficient (default: 0.0)
        eps: Small constant accepted for API compatibility; the sign-based
            update itself does not use it (default: 1e-8)
        maximize: Whether to maximize the objective (default: False)
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        eps: float = 1e-8,
        maximize: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")

        defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if group["maximize"]:
                    grad = -grad

                # Upcast low-precision gradients so the momentum math is stable.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]

                # Lazily initialize momentum in the (possibly upcast) gradient
                # dtype so the in-place updates below do not mix dtypes.
                if len(state) == 0:
                    state["momentum"] = torch.zeros_like(grad, memory_format=torch.preserve_format)

                momentum = state["momentum"]
                beta1, beta2 = group["betas"]

                # Decoupled weight decay (as in AdamW).
                if group["weight_decay"] != 0:
                    p.mul_(1 - group["lr"] * group["weight_decay"])

                # Update direction: sign of the beta1-interpolation between
                # momentum and gradient. mul() (not mul_()) leaves the stored
                # momentum untouched here.
                interpolated = momentum.mul(beta1).add_(grad, alpha=1 - beta1)

                # Parameter step; cast the sign back to the parameter dtype
                # (sign values are exactly representable in fp16/bf16).
                p.add_(torch.sign(interpolated).to(p.dtype), alpha=-group["lr"])

                # Momentum update uses beta2, decoupled from the update's beta1.
                momentum.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss
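

# A minimal sanity-check sketch (not part of the public API): verify one Lion
# step by hand on a single scalar parameter, using only the class defined above.
def _example_manual_lion_step() -> None:
    """With zero initial momentum, the first update direction is
    sign((1 - beta1) * grad), so the parameter moves by exactly -lr * sign(grad)."""
    p = torch.nn.Parameter(torch.tensor([1.0]))
    opt = Lion([p], lr=0.1, betas=(0.9, 0.99))
    p.grad = torch.tensor([0.5])
    opt.step()
    # interpolated = 0.9 * 0 + 0.1 * 0.5 = 0.05 -> sign = +1 -> p = 1.0 - 0.1
    assert torch.allclose(p.data, torch.tensor([0.9]))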


def configure_lion_optimizer(
    model: torch.nn.Module,
    lr: float = 1e-4,
    betas: Tuple[float, float] = (0.9, 0.99),
    weight_decay: float = 0.01,
    total_steps: Optional[int] = None,
    warmup_ratio: float = 0.1,
    **lion_kwargs
) -> Tuple[Lion, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """
    Configure Lion optimizer with OneCycle learning rate schedule.

    This function provides a drop-in replacement for BitTransformerLM's
    configure_optimizer function, using Lion instead of AdamW.

    Note: Lion typically works well with learning rates about 3-10x smaller
    than Adam/AdamW's, paired with a correspondingly higher weight decay
    (e.g. 0.01-0.1).

    Args:
        model: PyTorch model to optimize
        lr: Peak learning rate (typically smaller than Adam's)
        betas: Beta coefficients for momentum computation
        weight_decay: Weight decay coefficient (can be higher than Adam's)
        total_steps: Total training steps for OneCycle schedule
        warmup_ratio: Fraction of steps used for warmup
        **lion_kwargs: Additional arguments forwarded to the Lion optimizer

    Returns:
        Tuple of (optimizer, scheduler); scheduler is None when total_steps
        is not given
    """

    # Standard weight-decay grouping: apply decay to 2D+ tensors (weight
    # matrices, embeddings) and exempt 1D tensors (biases, norm gains).
    decay_params = []
    no_decay_params = []

    for param in model.parameters():
        if not param.requires_grad:
            continue

        if param.dim() >= 2:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    optimizer = Lion(
        param_groups,
        lr=lr,
        betas=betas,
        **lion_kwargs
    )

    scheduler = None
    if total_steps is not None and total_steps > 0:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=warmup_ratio,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=1e4,
        )

    return optimizer, scheduler
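

# Usage sketch: how configure_lion_optimizer slots into a generic PyTorch
# training loop. The model and data below are stand-ins; this does not assume
# BitTransformerLM's actual training API.
def _example_lion_training_loop() -> None:
    import torch.nn as nn

    model = nn.Linear(10, 1)
    optimizer, scheduler = configure_lion_optimizer(
        model, lr=1e-4, weight_decay=0.01, total_steps=100
    )
    for _ in range(100):
        x, y = torch.randn(8, 10), torch.randn(8, 1)
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # sign-based Lion update
        if scheduler is not None:
            scheduler.step()  # OneCycle: advance once per optimizer step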


def create_lion_training_config(
    lr: float = 1e-4,
    betas: Tuple[float, float] = (0.9, 0.99),
    weight_decay: float = 0.01,
    **kwargs
) -> Dict[str, Any]:
    """
    Create a training configuration dictionary for Lion optimizer.

    This can be used with BitTransformerLM's training scripts by passing
    the config to the training loop.

    Args:
        lr: Learning rate
        betas: Beta coefficients for momentum
        weight_decay: Weight decay coefficient
        **kwargs: Additional configuration options

    Returns:
        Dictionary containing training configuration
    """
    config = {
        "optimizer_type": "lion",
        "optimizer_config": {
            "lr": lr,
            "betas": betas,
            "weight_decay": weight_decay,
            **kwargs
        },
        "scheduler_type": "onecycle",
    }

    return config
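

# Sketch of how a training script might consume the config dict above. The
# dispatch below is an assumption about the caller's side, not an existing
# BitTransformerLM API.
def _example_build_optimizer_from_config(model: torch.nn.Module) -> Lion:
    config = create_lion_training_config(lr=3e-5, weight_decay=0.1)
    if config["optimizer_type"] != "lion":
        raise ValueError(f"Unexpected optimizer type: {config['optimizer_type']}")
    # With the defaults above, optimizer_config holds lr, betas, and
    # weight_decay -- all valid Lion constructor arguments.
    return Lion(model.parameters(), **config["optimizer_config"])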


class AdaptiveLion(Lion):
    """
    Enhanced Lion optimizer with adaptive learning rate scaling.

    This variant (an extension beyond the original Lion paper) rescales the
    learning rate per parameter tensor by the ratio of gradient norm to
    momentum norm:

        scale = clip(1 + adaptive_scale * (|g| / |m| - 1), min_scale, max_scale)

    so tensors whose gradients outpace their momentum take larger steps,
    potentially improving stability.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        eps: float = 1e-8,
        maximize: bool = False,
        adaptive_scale: float = 0.1,
        min_scale: float = 0.01,
        max_scale: float = 10.0,
    ):
        """
        Args (in addition to Lion's):
            adaptive_scale: Sensitivity of the scale to the norm ratio
            min_scale: Minimum learning rate scale
            max_scale: Maximum learning rate scale
        """
        self.adaptive_scale = adaptive_scale
        self.min_scale = min_scale
        self.max_scale = max_scale

        super().__init__(params, lr, betas, weight_decay, eps, maximize)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform optimization step with adaptive scaling."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if group["maximize"]:
                    grad = -grad

                # Upcast low-precision gradients (same as Lion.step).
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]

                # Initialize momentum in the (possibly upcast) gradient dtype
                # to avoid mixed-dtype in-place updates.
                if len(state) == 0:
                    state["momentum"] = torch.zeros_like(grad, memory_format=torch.preserve_format)
                    state["step"] = 0

                momentum = state["momentum"]
                state["step"] += 1
                beta1, beta2 = group["betas"]

                # Per-tensor norms for the adaptive scale. Note: .item() forces
                # a host-device sync per parameter, which costs throughput on GPU.
                grad_norm = grad.norm().item()
                momentum_norm = momentum.norm().item()

                # Scale the step by the gradient/momentum norm ratio, clipped to
                # [min_scale, max_scale]; fall back to 1.0 while momentum is
                # still (near) zero.
                if momentum_norm > 1e-8:
                    scale = 1.0 + self.adaptive_scale * (grad_norm / momentum_norm - 1.0)
                    scale = min(max(scale, self.min_scale), self.max_scale)
                else:
                    scale = 1.0

                adaptive_lr = group["lr"] * scale

                # Decoupled weight decay at the scaled learning rate.
                if group["weight_decay"] != 0:
                    p.mul_(1 - adaptive_lr * group["weight_decay"])

                # Same sign-based update as Lion, at the adaptive rate.
                interpolated = momentum.mul(beta1).add_(grad, alpha=1 - beta1)
                p.add_(torch.sign(interpolated).to(p.dtype), alpha=-adaptive_lr)
                momentum.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss
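

# Worked example of the scaling rule above (pure Python, illustrative only):
# scale = clip(1 + adaptive_scale * (|g|/|m| - 1), min_scale, max_scale).
def _example_adaptive_scale_behavior() -> None:
    adaptive_scale, min_scale, max_scale = 0.1, 0.01, 10.0
    for ratio in (0.5, 1.0, 2.0, 50.0):
        scale = min(max(1.0 + adaptive_scale * (ratio - 1.0), min_scale), max_scale)
        print(f"|grad|/|momentum| = {ratio:>4}: lr scale = {scale:.2f}")
    # -> 0.95, 1.00, 1.10, 5.90: mild boosts for gradient-dominated tensors,
    #    mild damping when momentum dominates.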


def configure_adaptive_lion_optimizer(
    model: torch.nn.Module,
    lr: float = 1e-4,
    adaptive_scale: float = 0.1,
    **kwargs
) -> Tuple[AdaptiveLion, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """Configure AdaptiveLion optimizer with learning rate scheduling.

    Accepts weight_decay, total_steps, and warmup_ratio via **kwargs; all
    remaining kwargs are forwarded to the AdaptiveLion constructor.
    """

    # Same weight-decay grouping as configure_lion_optimizer.
    decay_params = []
    no_decay_params = []

    for param in model.parameters():
        if not param.requires_grad:
            continue
        if param.dim() >= 2:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": kwargs.get("weight_decay", 0.01)},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    # Keys consumed here or by the scheduler must not reach the optimizer
    # constructor, which would reject them as unexpected keyword arguments.
    non_optimizer_keys = {"weight_decay", "total_steps", "warmup_ratio"}
    optimizer = AdaptiveLion(
        param_groups,
        lr=lr,
        adaptive_scale=adaptive_scale,
        **{k: v for k, v in kwargs.items() if k not in non_optimizer_keys}
    )

    scheduler = None
    total_steps = kwargs.get("total_steps")
    if total_steps is not None and total_steps > 0:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=kwargs.get("warmup_ratio", 0.1),
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=1e4,
        )

    return optimizer, scheduler


def integrate_with_bittransformerlm():
    """
    Example of how to integrate Lion optimizer with BitTransformerLM training.

    Usage:
        from BTLM_Extensions.lion_optimizer import configure_lion_optimizer

        # Replace the standard optimizer configuration
        # Note: Lion typically needs smaller learning rates than Adam
        optimizer, scheduler = configure_lion_optimizer(
            model, lr=1e-4, weight_decay=0.01, total_steps=1000
        )

        # Use in training loop
        train_loop(model, data, optimizer=optimizer, scheduler=scheduler)

        # For adaptive version:
        from BTLM_Extensions.lion_optimizer import configure_adaptive_lion_optimizer

        optimizer, scheduler = configure_adaptive_lion_optimizer(
            model, lr=1e-4, adaptive_scale=0.1, total_steps=1000
        )
    """
    pass


if __name__ == "__main__":
    # Smoke tests for both optimizers on a tiny regression model.
    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    )

    print("Testing standard Lion optimizer...")
    optimizer, scheduler = configure_lion_optimizer(model, lr=1e-4, total_steps=100)

    # One forward/backward/step round trip.
    x = torch.randn(32, 10)
    y = torch.randn(32, 1)

    pred = model(x)
    loss = nn.functional.mse_loss(pred, y)
    initial_loss = loss.item()
    loss.backward()

    optimizer.step()
    if scheduler:
        scheduler.step()

    print(f"Initial loss: {initial_loss:.4f}")


    print("Testing Adaptive Lion optimizer...")
    model2 = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    )

    optimizer2, scheduler2 = configure_adaptive_lion_optimizer(
        model2, lr=1e-4, adaptive_scale=0.1, total_steps=100
    )

    pred2 = model2(x)
    loss2 = nn.functional.mse_loss(pred2, y)
    initial_loss2 = loss2.item()
    loss2.backward()
    optimizer2.step()
    if scheduler2:
        scheduler2.step()

    print("Lion optimizers test completed successfully!")
    print(f"Standard Lion initial loss: {initial_loss:.4f}")
    print(f"Adaptive Lion initial loss: {initial_loss2:.4f}")