| """
|
| MPS optimizations for Vortex model on Apple Silicon.
|
| Uses PyTorch MPS backend with MPS-compatible ops only.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| from typing import Optional, Dict, Any
|
|
|
|
|
def optimize_for_mps(
    model: nn.Module,
    config: Dict,
    use_sdpa: bool = True,
) -> nn.Module:
    """
    Apply MPS optimizations to model.

    Moves the model to the MPS device, casts it to the configured dtype,
    and optionally routes attention through PyTorch SDPA.

    Args:
        model: VortexModel
        config: Model config; reads config["dtype"] (default "bfloat16")
        use_sdpa: Use PyTorch scaled dot product attention (MPS compatible)

    Returns:
        Optimized model

    Raises:
        RuntimeError: If the MPS backend is not available.
    """
    # Fail fast with a clear message instead of the opaque error that
    # model.to("mps") would raise on non-Apple hardware.
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS backend is not available on this machine")

    device = torch.device("mps")
    model = model.to(device)

    # MPS has incomplete bfloat16 support, so "bfloat16" is deliberately
    # downgraded to float16. An explicit "float16" request (previously
    # falling through to float32) now also maps to float16; anything else
    # stays in full precision.
    dtype_str = config.get("dtype", "bfloat16")
    if dtype_str in ("bfloat16", "float16"):
        dtype = torch.float16
    else:
        dtype = torch.float32
    model = model.to(dtype)

    if use_sdpa:
        model = _apply_sdpa(model)
        print("Applied PyTorch SDPA for MPS")

    return model
|
|
|
|
|
| def _apply_sdpa(model: nn.Module) -> nn.Module:
|
| """
|
| Replace custom attention with PyTorch SDPA.
|
| SDPA is optimized for MPS backend.
|
| """
|
| for name, module in model.named_modules():
|
| if hasattr(module, 'attn') and hasattr(module.attn, 'forward_optimized'):
|
|
|
| original_forward = module.attn.forward
|
|
|
| def sdpa_forward(self, x, *args, **kwargs):
|
| return self._standard_attention(x, kwargs.get('attention_mask'))
|
|
|
| module.attn.forward = sdpa_forward.__get__(module.attn, type(module.attn))
|
|
|
| return model
|
|
|
|
|
def get_mps_memory_usage() -> Dict[str, float]:
    """Get current MPS memory usage in GB."""
    if not torch.backends.mps.is_available():
        return {"error": "MPS not available"}

    # psutil is imported lazily so the module loads even where it is absent.
    import psutil

    mem = psutil.Process().memory_info()
    return {
        "rss_gb": mem.rss / 1e9,
        "vms_gb": mem.vms / 1e9,
    }
|
|
|
|
|
def profile_model_mps(
    model: nn.Module,
    input_ids: torch.Tensor,
    num_warmup: int = 10,
    num_runs: int = 50,
) -> Dict[str, float]:
    """
    Profile model performance on MPS.

    Args:
        model: Model to profile
        input_ids: Example input, assumed shape (batch, seq_len)
        num_warmup: Number of warmup runs (excluded from timing)
        num_runs: Number of timed profiling runs

    Returns:
        Dictionary with timing statistics:
        - "avg_time_sec": mean wall-clock seconds per forward pass
        - "tokens_per_sec": throughput counting every token in the batch
    """
    import time

    model.eval()
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)

    # Warmup: first MPS runs include kernel compilation/caching overhead.
    with torch.no_grad():
        for _ in range(num_warmup):
            _ = model(input_ids)
    # Drain pending async MPS work so warmup cost never leaks into timing.
    if device.type == "mps":
        torch.mps.synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_ids)
        # Synchronize inside the timed region so async kernels are counted.
        if device.type == "mps":
            torch.mps.synchronize()
    elapsed = time.perf_counter() - start

    avg_time = elapsed / num_runs
    # Bug fix: the original used input_ids.shape[1] (seq_len only), which
    # understated throughput by the batch size; count all tokens processed.
    tokens_per_sec = input_ids.numel() / avg_time

    return {
        "avg_time_sec": avg_time,
        "tokens_per_sec": tokens_per_sec,
    }
|
|
|
|
|
def test_mps_optimize():
    """Smoke-test the MPS optimization path on a tiny Vortex model."""
    if not torch.backends.mps.is_available():
        print("MPS not available, skipping test")
        return

    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config so the test builds and runs quickly.
    config = VORTEX_7B_CONFIG.copy()
    config.update(d_model=512, num_layers=2, num_heads=8, vocab_size=1000)

    model = VortexModel(config)
    print(f"Model parameters: {model.get_num_params():,}")

    model = optimize_for_mps(model, config, use_sdpa=True)

    # One forward pass on random token ids (batch=2, seq=128).
    input_ids = torch.randint(0, config["vocab_size"], (2, 128)).to("mps")
    with torch.no_grad():
        logits = model(input_ids)["logits"]

    print(f"Output shape: {logits.shape}")
    print("MPS optimize test passed!")
|
|
|
|
|
| if __name__ == "__main__":
|
| test_mps_optimize()
|
|
|