"""Configuration objects for GRPO training.""" from __future__ import annotations import logging import os from dataclasses import dataclass, field from pathlib import Path _logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Device options # --------------------------------------------------------------------------- # "auto" — use GPU/MPS if available, fall back to CPU # "cpu" — force CPU (use on Mac where MPS OOMs during GRPO) # "cuda" — force CUDA (use on Colab / cloud GPU) # "mps" — force MPS (only if model fits; unlikely for GRPO) DEVICE_AUTO = "auto" DEVICE_CPU = "cpu" DEVICE_CUDA = "cuda" DEVICE_MPS = "mps" def find_project_root() -> Path: """Walk up from cwd until we find pyproject.toml.""" d = Path.cwd() for parent in [d, *d.parents]: if (parent / "pyproject.toml").exists(): return parent raise FileNotFoundError("Could not locate project root (no pyproject.toml found)") def apply_device_overrides(device: str) -> None: """Set environment/backend flags so PyTorch and HuggingFace respect *device*. Call this before importing transformers or loading models. Why this exists: GRPO generates multiple completions per prompt, so peak memory is several times the model size. On Mac (MPS, typically 16 GB shared), even a 0.6B model OOMs. Forcing CPU avoids the crash at the cost of speed. On Colab/cloud, "auto" or "cuda" is the right choice. """ if device == DEVICE_AUTO: return if device == DEVICE_CPU: os.environ["CUDA_VISIBLE_DEVICES"] = "" try: import torch torch.backends.mps.is_available = lambda: False # type: ignore[assignment] except ImportError: pass _logger.info("Device forced to CPU — MPS and CUDA disabled") return if device == DEVICE_CUDA: try: import torch torch.backends.mps.is_available = lambda: False # type: ignore[assignment] except ImportError: pass _logger.info("Device forced to CUDA — MPS disabled") return # "mps" — no overrides needed, PyTorch will use MPS if available @dataclass class GRPOConfig: """Configuration for GRPO training on SQLEnv. Parameters ---------- questions_path Path to the training questions JSON file. db_dir Directory containing SQLite databases. output_dir Directory where checkpoints and outputs are written. device Device policy: "auto", "cpu", "cuda", or "mps". Use "cpu" on Mac (MPS OOMs with GRPO). Use "auto" or "cuda" on Colab / cloud GPU. """ questions_path: str db_dir: str output_dir: str model_name: str = "Qwen/Qwen3-0.6B" device: str = DEVICE_AUTO max_new_tokens: int = 256 num_train_epochs: int = 1 per_device_train_batch_size: int = 2 gradient_accumulation_steps: int = 4 learning_rate: float = 5e-6 num_generations: int = 4 step_budget: int = 10 difficulty_filter: list[str] = field(default_factory=lambda: ["easy", "medium"]) seed: int = 42 logging_steps: int = 10 def __post_init__(self) -> None: valid_devices = {DEVICE_AUTO, DEVICE_CPU, DEVICE_CUDA, DEVICE_MPS} if self.device not in valid_devices: msg = f"device must be one of {valid_devices}, got '{self.device}'" raise ValueError(msg) if self.max_new_tokens <= 0: raise ValueError("max_new_tokens must be > 0") if self.num_train_epochs <= 0: raise ValueError("num_train_epochs must be > 0") if self.per_device_train_batch_size <= 0: raise ValueError("per_device_train_batch_size must be > 0") if self.gradient_accumulation_steps <= 0: raise ValueError("gradient_accumulation_steps must be > 0") if self.learning_rate <= 0: raise ValueError("learning_rate must be > 0") if self.num_generations <= 0: raise ValueError("num_generations must be > 0") if self.step_budget < 0: raise ValueError("step_budget must be >= 0") if self.logging_steps <= 0: raise ValueError("logging_steps must be > 0") if not self.difficulty_filter: raise ValueError("difficulty_filter must not be empty")