| """Train and run quantized self-driving lab models with Unsloth.
|
|
|
| This keeps the same OpenEnv prompt + reward wiring as `training_script.py`,
|
| but arranges the Unsloth path in the more typical pattern:
|
| 1. patch GRPO support
|
| 2. load a quantized model
|
| 3. apply LoRA adapters
|
| 4. train with an explicit OpenEnv reward function
|
|
|
| NOTE: Unsloth must be imported before trl, transformers, peft. Import this
|
| module before training_script.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import random
|
| from pathlib import Path
|
| from typing import Any, Dict, Optional, Sequence
|
|
|
|
|
| import unsloth
|
|
|
| import training_script as base
|
|
|
# Defaults for the Unsloth GRPO fine-tuning run; all are overridable via CLI flags
# added in build_argument_parser().
DEFAULT_OUTPUT_DIR = "training/grpo-unsloth-output"  # checkpoints, plots, exports land here
DEFAULT_MAX_SEQ_LENGTH = 2048  # context length passed to Unsloth model loading
DEFAULT_LORA_R = 16  # LoRA rank (--lora-r)
DEFAULT_LORA_ALPHA = 16  # LoRA scaling alpha (--lora-alpha)
DEFAULT_LORA_DROPOUT = 0.0  # LoRA dropout probability (--lora-dropout)
# Projection layers that receive LoRA adapters (attention + MLP projections,
# the usual Unsloth/llama-style target set).
LORA_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
|
|
|
|
|
def require_unsloth():
    """Import Unsloth lazily and return ``(FastLanguageModel, PatchFastRL)``.

    Wraps the import in actionable error messages: a vllm-related failure
    gets install hints, a missing-unsloth failure gets an install command,
    and anything else is re-raised as a generic RuntimeError.

    Raises:
        RuntimeError: when Unsloth (or a compatible vllm) is not importable.
    """
    try:
        from unsloth import FastLanguageModel, PatchFastRL
    except ImportError as exc:
        detail = str(exc)
        lowered = detail.lower()
        # unsloth_zoo breaks against some vllm releases; point at known-good pins.
        if "vllm.lora" in detail or "vllm" in lowered:
            hint = (
                f"Unsloth failed: {exc}. "
                "unsloth_zoo expects vllm.lora.models. Install a compatible vllm:\n"
                " pip install 'vllm==0.8.2' # requires torch 2.6\n"
                " pip install 'vllm==0.7.3' # alternative\n"
                "If torch>=2.10 conflicts, use a separate env with torch 2.6–2.8."
            )
            raise RuntimeError(hint) from exc
        if "unsloth" in lowered:
            raise RuntimeError(
                "Unsloth is not installed. Run `uv sync` or `pip install unsloth`."
            ) from exc
        raise RuntimeError(f"Failed to import Unsloth: {exc}") from exc
    return FastLanguageModel, PatchFastRL
|
|
|
|
|
| def _call_unsloth_from_pretrained(FastLanguageModel, **kwargs: Any):
|
| for optional_key in ("fast_inference", "trust_remote_code"):
|
| try:
|
| return FastLanguageModel.from_pretrained(**kwargs)
|
| except TypeError as exc:
|
| if optional_key in kwargs and optional_key in str(exc):
|
| kwargs = dict(kwargs)
|
| kwargs.pop(optional_key, None)
|
| continue
|
| raise
|
| return FastLanguageModel.from_pretrained(**kwargs)
|
|
|
|
|
def build_argument_parser() -> argparse.ArgumentParser:
    """Extend the base training parser with Unsloth-specific flags.

    Starts from ``training_script``'s parser, retargets the default output
    directory, and registers quantization / LoRA / export options.
    """
    parser = base.build_argument_parser()
    parser.description = (
        "Train a GRPO policy with Unsloth quantized loading for faster H100 runs."
    )
    parser.set_defaults(output_dir=DEFAULT_OUTPUT_DIR)

    # (flag, add_argument kwargs) specs, registered in order so --help output
    # matches the original hand-written sequence.
    extra_flags = (
        ("--max-seq-length", dict(
            type=int,
            default=DEFAULT_MAX_SEQ_LENGTH,
            help="Context length passed to Unsloth model loading.",
        )),
        ("--disable-4bit", dict(
            action="store_true",
            help="Disable 4-bit quantized loading and use the wider base weights.",
        )),
        ("--lora-r", dict(
            type=int,
            default=DEFAULT_LORA_R,
            help="LoRA rank used for the quantized GRPO policy.",
        )),
        ("--lora-alpha", dict(
            type=int,
            default=DEFAULT_LORA_ALPHA,
            help="LoRA alpha used for the quantized GRPO policy.",
        )),
        ("--lora-dropout", dict(
            type=float,
            default=DEFAULT_LORA_DROPOUT,
            help="LoRA dropout used for the quantized GRPO policy.",
        )),
        ("--save-merged-16bit", dict(
            action="store_true",
            help="Also export a merged 16-bit model after training if supported.",
        )),
    )
    for flag, options in extra_flags:
        parser.add_argument(flag, **options)
    return parser
|
|
|
|
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse CLI arguments (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = build_argument_parser()
    return parser.parse_args(argv)
|
|
|
|
|
def make_training_args(**overrides: Any) -> argparse.Namespace:
    """Build a Namespace of parser defaults with keyword overrides applied.

    Raises:
        ValueError: if any override key is not a known training argument.
    """
    defaults = vars(build_argument_parser().parse_args([]))
    unrecognised = sorted(name for name in overrides if name not in defaults)
    if unrecognised:
        raise ValueError(f"Unknown training args: {', '.join(unrecognised)}")
    return argparse.Namespace(**{**defaults, **overrides})
|
|
|
|
|
def load_model_artifacts(
    model_id: str,
    *,
    trust_remote_code: bool,
    max_seq_length: int = DEFAULT_MAX_SEQ_LENGTH,
    load_in_4bit: bool = True,
    fast_inference: bool = False,
    prepare_for_inference: bool = False,
):
    """Load an Unsloth model and tokenizer; return ``(tokenizer, model)``.

    Args:
        model_id: HF hub id or local path passed to Unsloth as ``model_name``.
        trust_remote_code: forwarded to ``from_pretrained`` (may be dropped by
            the compatibility shim on older Unsloth builds).
        max_seq_length: context window passed to Unsloth model loading.
        load_in_4bit: request 4-bit quantized weights.
        fast_inference: forwarded to Unsloth (may be dropped by the shim).
        prepare_for_inference: if True, call ``FastLanguageModel.for_inference``
            when this Unsloth build provides it.

    Note the return order: tokenizer first, then model.
    """
    FastLanguageModel, _ = require_unsloth()
    runtime = base.resolve_torch_runtime()

    print(f"Loading Unsloth tokenizer+model for {model_id} ...")
    # Shim strips optional kwargs that this Unsloth version does not accept.
    model, tokenizer = _call_unsloth_from_pretrained(
        FastLanguageModel,
        model_name=model_id,
        max_seq_length=max_seq_length,
        dtype=runtime["dtype"],
        load_in_4bit=load_in_4bit,
        fast_inference=fast_inference,
        trust_remote_code=trust_remote_code,
    )
    # Some tokenizers ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    if prepare_for_inference:
        try:
            FastLanguageModel.for_inference(model)
        except AttributeError:
            # Older Unsloth builds lack for_inference; best-effort only.
            pass

    # Resolve the device for logging: model attribute, first parameter, or
    # the runtime default when the model has no parameters at all.
    device = getattr(model, "device", None)
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = runtime["device"]
    print(f"Loaded model on device: {device}")
    return tokenizer, model
|
|
|
|
|
def build_openenv_reward(args: argparse.Namespace) -> base.OpenEnvReward:
    """Return the OpenEnv-compatible reward callable used by GRPO."""
    reward_kwargs = {
        "reward_backend": args.reward_backend,
        "base_url": args.base_url,
        "domain_randomise": args.domain_randomise,
    }
    return base.OpenEnvReward(**reward_kwargs)
|
|
|
|
|
def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
    """Build the OpenEnv rollout states that seed GRPO prompts.

    Returns a dict with the selected ``scenario_names`` and the collected
    prompt ``examples``.
    """
    names = base.selected_scenarios(args.scenario_name)
    rollout_examples = base.build_prompt_examples(
        dataset_episodes=args.dataset_episodes,
        rollout_steps=args.rollout_steps,
        collection_policy=args.collection_policy,
        scenario_names=names,
        seed=args.seed,
        domain_randomise=args.domain_randomise,
    )
    return {"scenario_names": names, "examples": rollout_examples}
|
|
|
|
|
def patch_unsloth_grpo():
    """Patch TRL GRPO to use Unsloth's optimized kernels."""
    fast_lm, patch_fast_rl = require_unsloth()
    # Must run before GRPOTrainer is constructed so TRL picks up the patch.
    patch_fast_rl("GRPO", fast_lm)
    return fast_lm
|
|
|
|
|
def apply_lora_adapters(FastLanguageModel, model: Any, args: argparse.Namespace) -> Any:
    """Apply LoRA adapters in the usual Unsloth configuration style."""
    lora_kwargs = {
        "r": args.lora_r,
        "target_modules": LORA_TARGET_MODULES,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "bias": "none",
        "use_gradient_checkpointing": True,
        "random_state": args.seed,
    }
    return FastLanguageModel.get_peft_model(model, **lora_kwargs)
|
|
|
|
|
def build_grpo_config(
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    """Construct a ``GRPOConfig``, skipping fields this trl version lacks.

    Introspects ``GRPOConfig.__init__`` and passes only supported keyword
    arguments, printing the names of any that had to be dropped.
    """
    import inspect

    base._guard_invalid_torchao_version()
    base._guard_partial_vllm_install()
    from trl import GRPOConfig

    supported = set(inspect.signature(GRPOConfig.__init__).parameters)
    candidate = {
        "output_dir": args.output_dir,
        "learning_rate": args.learning_rate,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_generations": args.num_generations,
        "max_completion_length": args.max_completion_length,
        "num_train_epochs": args.num_train_epochs,
        "logging_steps": args.logging_steps,
        "save_steps": args.save_steps,
        "bf16": runtime["bf16"],
        "fp16": runtime["fp16"],
        "report_to": "none",
        "remove_unused_columns": False,
    }

    if "max_prompt_length" in supported:
        # Explicitly disable prompt truncation when the field exists.
        candidate["max_prompt_length"] = None
    legacy_max_length_only = (
        "max_length" in supported
        and "max_prompt_length" not in supported
        and "max_completion_length" not in supported
    )
    if legacy_max_length_only:
        # Very old trl: a single combined length budget.
        candidate["max_length"] = (
            getattr(args, "max_prompt_length", 1024) + args.max_completion_length
        )

    accepted = {name: value for name, value in candidate.items() if name in supported}
    dropped = sorted(set(candidate) - set(accepted))
    if dropped:
        print(f"GRPOConfig compatibility: skipping unsupported fields {', '.join(dropped)}")
    return GRPOConfig(**accepted)
|
|
|
|
|
def build_unsloth_grpo_trainer(
    *,
    model: Any,
    tokenizer: Any,
    reward_func: Any,
    train_dataset: Any,
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    """Instantiate a TRL ``GRPOTrainer`` wired to the OpenEnv reward."""
    base._guard_invalid_torchao_version()
    base._guard_partial_vllm_install()
    from trl import GRPOTrainer

    trainer_config = build_grpo_config(args, runtime)
    trainer_kwargs = {
        "model": model,
        "reward_funcs": reward_func,
        "args": trainer_config,
        "train_dataset": train_dataset,
        "processing_class": tokenizer,
    }
    return GRPOTrainer(**trainer_kwargs)
|
|
|
|
|
def generate_action_with_model(
    model: Any,
    tokenizer: Any,
    prompt_or_observation: str | base.ExperimentObservation,
    *,
    max_new_tokens: int = base.DEFAULT_COMPLETION_TOKEN_BUDGET,
    temperature: float = 0.2,
    top_p: float = 0.9,
    do_sample: bool = True,
) -> Dict[str, Any]:
    """Sample one completion from the model and parse it into an action.

    Args:
        model: a loaded (possibly Unsloth) causal LM with ``.generate``.
        tokenizer: matching tokenizer used to encode/decode.
        prompt_or_observation: either a ready prompt string or an
            ``ExperimentObservation`` to be rendered via
            ``base.build_training_prompt``.
        max_new_tokens / temperature / top_p / do_sample: sampling controls
            forwarded to ``model.generate``.

    Returns:
        Dict with ``prompt``, raw ``response_text``, and the parsed ``action``
        (``None`` when the completion could not be parsed).
    """
    import torch

    if isinstance(prompt_or_observation, base.ExperimentObservation):
        prompt = base.build_training_prompt(prompt_or_observation)
    else:
        prompt = str(prompt_or_observation)

    # Resolve the device inputs must live on: model attribute, first
    # parameter, or the runtime default when the model has no parameters.
    model_device = getattr(model, "device", None)
    if model_device is None:
        try:
            model_device = next(model.parameters()).device
        except StopIteration:
            model_device = base.resolve_torch_runtime()["device"]

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {key: value.to(model_device) for key, value in inputs.items()}
    # Remember prompt length so only newly generated tokens are decoded.
    prompt_tokens = inputs["input_ids"].shape[1]

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
        "pad_token_id": tokenizer.pad_token_id,
    }
    # Inference only: no gradients needed.
    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)

    # Strip the echoed prompt; decode just the completion.
    new_tokens = output_ids[0][prompt_tokens:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    action = base.parse_action_completion(response_text)
    # When we have the observation, fill in any conclusion claims the model
    # omitted so the action is complete for the environment.
    if action is not None and isinstance(prompt_or_observation, base.ExperimentObservation):
        action = base.ensure_conclusion_claims(prompt_or_observation, action)
    return {
        "prompt": prompt,
        "response_text": response_text,
        "action": action,
    }
|
|
|
|
|
def run_training(args: argparse.Namespace) -> Dict[str, Any]:
    """Run the full Unsloth GRPO pipeline described by ``args``.

    Handles three modes in order: ``--load-model-only`` (load and report,
    no training), ``--dry-run`` (preview prompts/rewards, no training), and
    the full train/save/export/push/plot path. Returns a dict of the
    artifacts produced in whichever mode ran.
    """
    random.seed(args.seed)
    runtime = base.resolve_torch_runtime()

    # Mode 1: just load the model and report details, then stop.
    if args.load_model_only:
        tokenizer, model = load_model_artifacts(
            args.model_id,
            trust_remote_code=args.trust_remote_code,
            max_seq_length=args.max_seq_length,
            load_in_4bit=not args.disable_4bit,
            fast_inference=False,
            prepare_for_inference=True,
        )
        device = getattr(model, "device", "unknown")
        print(f"Unsloth model ready: {args.model_id}")
        print(f"Tokenizer vocab size: {len(tokenizer)}")
        print(f"Model device: {device}")
        print(f"Runtime device name: {runtime['device_name']}")
        return {
            "args": args,
            "runtime": runtime,
            "tokenizer": tokenizer,
            "model": model,
        }

    # Build OpenEnv prompt examples + reward before touching the model.
    prompt_data = prepare_prompt_examples(args)
    scenario_names = prompt_data["scenario_names"]
    examples = prompt_data["examples"]
    env_reward = build_openenv_reward(args)

    # Mode 2: dry run — preview prompts/rewards without loading the model.
    if args.dry_run:
        base.run_dry_run_preview(examples, env_reward, args.output_dir)
        return {
            "args": args,
            "runtime": runtime,
            "scenario_names": scenario_names,
            "examples": examples,
            "reward_fn": env_reward,
        }

    from datasets import Dataset

    # Patch TRL GRPO with Unsloth kernels BEFORE loading the model.
    FastLanguageModel = patch_unsloth_grpo()
    train_dataset = Dataset.from_list(examples)

    tokenizer, model = load_model_artifacts(
        args.model_id,
        trust_remote_code=args.trust_remote_code,
        max_seq_length=args.max_seq_length,
        load_in_4bit=not args.disable_4bit,
        fast_inference=False,
    )

    model = apply_lora_adapters(FastLanguageModel, model, args)

    print(
        f"Unsloth training runtime: device={runtime['device']} "
        f"name={runtime['device_name']} "
        f"dtype={runtime['dtype']} "
        f"load_in_4bit={not args.disable_4bit}"
    )
    print(
        "OpenEnv reward: "
        f"backend={args.reward_backend} scenarios={len(scenario_names)} "
        f"examples={len(examples)}"
    )

    trainer = build_unsloth_grpo_trainer(
        model=model,
        tokenizer=tokenizer,
        reward_func=env_reward,
        train_dataset=train_dataset,
        args=args,
        runtime=runtime,
    )

    # Some trl versions expect vision token attributes even on text-only
    # trainers; backfill them with None so training doesn't crash.
    for attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
        if not hasattr(trainer, attr):
            setattr(trainer, attr, None)
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    # Optional merged 16-bit export (Unsloth-specific; best-effort).
    if args.save_merged_16bit:
        merged_dir = Path(args.output_dir) / "merged_16bit"
        try:
            model.save_pretrained_merged(
                str(merged_dir),
                tokenizer,
                save_method="merged_16bit",
            )
            print(f"Saved merged 16-bit model to {merged_dir}")
        except AttributeError:
            print("Merged 16-bit export is not available in this Unsloth build; skipping.")

    # Optional push of the output folder to the HuggingFace Hub.
    if args.push_to_hub:
        from huggingface_hub import HfApi

        api = HfApi()
        api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
        print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.push_to_hub,
            repo_type="model",
            create_pr=False,
        )
        print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")

    # Persist training curves from the trainer's log history.
    plot_paths = base.save_training_plots(
        trainer.state.log_history,
        args.output_dir,
        metric_key=args.plot_metric_key,
    )
    print("Saved training plots:")
    for plot_name, plot_path in plot_paths.items():
        print(f" - {plot_name}: {plot_path}")

    return {
        "args": args,
        "runtime": runtime,
        "scenario_names": scenario_names,
        "examples": examples,
        "reward_fn": env_reward,
        "train_dataset": train_dataset,
        "tokenizer": tokenizer,
        "model": model,
        "trainer": trainer,
        "plot_paths": plot_paths,
    }
|
|
|
|
|
def main() -> None:
    """CLI entry point: parse arguments and run the training pipeline."""
    cli_args = parse_args()
    run_training(cli_args)
|
|
|
|
|
# Allow direct execution: `python <this module>.py ...`
if __name__ == "__main__":
    main()
|
|
|