# sql_env/training/config.py
# (Hugging Face Hub upload header removed: uploaded by hjerpe via
# huggingface_hub, commit 5dd1bb4 verified.)
"""Configuration objects for GRPO training."""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
_logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Device options
# ---------------------------------------------------------------------------
# "auto" — use GPU/MPS if available, fall back to CPU
# "cpu" — force CPU (use on Mac where MPS OOMs during GRPO)
# "cuda" — force CUDA (use on Colab / cloud GPU)
# "mps" — force MPS (only if model fits; unlikely for GRPO)
DEVICE_AUTO = "auto"
DEVICE_CPU = "cpu"
DEVICE_CUDA = "cuda"
DEVICE_MPS = "mps"
def find_project_root() -> Path:
"""Walk up from cwd until we find pyproject.toml."""
d = Path.cwd()
for parent in [d, *d.parents]:
if (parent / "pyproject.toml").exists():
return parent
raise FileNotFoundError("Could not locate project root (no pyproject.toml found)")
def apply_device_overrides(device: str) -> None:
"""Set environment/backend flags so PyTorch and HuggingFace respect *device*.
Call this before importing transformers or loading models.
Why this exists: GRPO generates multiple completions per prompt, so peak
memory is several times the model size. On Mac (MPS, typically 16 GB
shared), even a 0.6B model OOMs. Forcing CPU avoids the crash at the
cost of speed. On Colab/cloud, "auto" or "cuda" is the right choice.
"""
if device == DEVICE_AUTO:
return
if device == DEVICE_CPU:
os.environ["CUDA_VISIBLE_DEVICES"] = ""
try:
import torch
torch.backends.mps.is_available = lambda: False # type: ignore[assignment]
except ImportError:
pass
_logger.info("Device forced to CPU β€” MPS and CUDA disabled")
return
if device == DEVICE_CUDA:
try:
import torch
torch.backends.mps.is_available = lambda: False # type: ignore[assignment]
except ImportError:
pass
_logger.info("Device forced to CUDA β€” MPS disabled")
return
# "mps" β€” no overrides needed, PyTorch will use MPS if available
@dataclass
class GRPOConfig:
"""Configuration for GRPO training on SQLEnv.
Parameters
----------
questions_path
Path to the training questions JSON file.
db_dir
Directory containing SQLite databases.
output_dir
Directory where checkpoints and outputs are written.
device
Device policy: "auto", "cpu", "cuda", or "mps".
Use "cpu" on Mac (MPS OOMs with GRPO).
Use "auto" or "cuda" on Colab / cloud GPU.
"""
questions_path: str
db_dir: str
output_dir: str
model_name: str = "Qwen/Qwen3-0.6B"
device: str = DEVICE_AUTO
max_new_tokens: int = 256
num_train_epochs: int = 1
per_device_train_batch_size: int = 2
gradient_accumulation_steps: int = 4
learning_rate: float = 5e-6
num_generations: int = 4
step_budget: int = 10
difficulty_filter: list[str] = field(default_factory=lambda: ["easy", "medium"])
seed: int = 42
logging_steps: int = 10
def __post_init__(self) -> None:
valid_devices = {DEVICE_AUTO, DEVICE_CPU, DEVICE_CUDA, DEVICE_MPS}
if self.device not in valid_devices:
msg = f"device must be one of {valid_devices}, got '{self.device}'"
raise ValueError(msg)
if self.max_new_tokens <= 0:
raise ValueError("max_new_tokens must be > 0")
if self.num_train_epochs <= 0:
raise ValueError("num_train_epochs must be > 0")
if self.per_device_train_batch_size <= 0:
raise ValueError("per_device_train_batch_size must be > 0")
if self.gradient_accumulation_steps <= 0:
raise ValueError("gradient_accumulation_steps must be > 0")
if self.learning_rate <= 0:
raise ValueError("learning_rate must be > 0")
if self.num_generations <= 0:
raise ValueError("num_generations must be > 0")
if self.step_budget < 0:
raise ValueError("step_budget must be >= 0")
if self.logging_steps <= 0:
raise ValueError("logging_steps must be > 0")
if not self.difficulty_filter:
raise ValueError("difficulty_filter must not be empty")