| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from dataclasses import asdict, dataclass |
| from datetime import UTC, datetime |
| from pathlib import Path |
| from typing import Final |
|
|
| import gymnasium as gym |
| import numpy as np |
| from gymnasium import spaces |
| from stable_baselines3 import PPO |
|
|
| from fusion_lab.models import StellaratorAction, StellaratorObservation |
| from server.contract import RESET_SEEDS |
| from server.environment import BUDGET, StellaratorEnvironment |
|
|
# Defaults surfaced through the command-line flags.
DEFAULT_OUTPUT_DIR: Final[Path] = Path("training/artifacts/ppo_smoke")
DEFAULT_EVAL_EPISODES: Final[int] = 3
DEFAULT_TOTAL_TIMESTEPS: Final[int] = 32

# Number of entries in the flat observation vector handed to the policy.
ENCODED_OBSERVATION_DIM: Final[int] = 17

# Seed values cycled through by resets that are not given an explicit seed.
TRAIN_RESET_SEED_INDICES: Final[tuple[int, ...]] = (2,)

# Each entry is a (parameter, direction, magnitude) triple; its position in
# the tuple is the discrete action index.
DIAGNOSTIC_RUN_ACTION_SPECS: Final[tuple[tuple[str, str, str], ...]] = (
    ("rotational_transform", "increase", "medium"),
    ("triangularity_scale", "increase", "medium"),
)
LOW_FI_ACTION_COUNT: Final[int] = len(DIAGNOSTIC_RUN_ACTION_SPECS)
|
|
|
|
@dataclass(frozen=True)
class TraceStep:
    """Immutable record of one environment step taken during evaluation."""

    # Which step this was and what the policy chose.
    step: int
    action_index: int
    action_label: str

    # Outcome reported for the step.
    reward: float
    score: float
    feasibility: float
    constraints_satisfied: bool
    evaluation_failed: bool
    budget_remaining: int
    termination_reason: str

    # Metrics copied from the step's info dict.
    max_elongation: float
    average_triangularity: float
    edge_iota_over_nfp: float
|
|
|
|
@dataclass(frozen=True)
class EpisodeTrace:
    """Immutable summary of one evaluation episode plus its per-step records."""

    # Episode identity.
    episode: int
    seed: int

    # Episode-level outcome.
    total_reward: float
    final_score: float
    final_feasibility: float
    constraints_satisfied: bool
    evaluation_failed: bool
    termination_reason: str

    # Step records in the order the steps were taken.
    steps: list[TraceStep]
|
|
|
|
def diagnostic_action(action_index: int) -> StellaratorAction:
    """Translate a discrete action index into a ``run`` action.

    Args:
        action_index: Position of the wanted entry in
            ``DIAGNOSTIC_RUN_ACTION_SPECS``.

    Returns:
        A ``StellaratorAction`` with ``intent="run"`` and the parameter,
        direction and magnitude of the selected entry.
    """
    parameter, direction, magnitude = DIAGNOSTIC_RUN_ACTION_SPECS[action_index]
    action_fields = {
        "intent": "run",
        "parameter": parameter,
        "direction": direction,
        "magnitude": magnitude,
    }
    return StellaratorAction(**action_fields)
|
|
|
|
def diagnostic_action_label(action_index: int) -> str:
    """Return the label ``"<parameter> <direction> <magnitude>"`` for an action index.

    The label is read back from the action built by ``diagnostic_action``.
    """
    built = diagnostic_action(action_index)
    parameter, direction, magnitude = built.parameter, built.direction, built.magnitude
    return f"{parameter} {direction} {magnitude}"
|
|
|
|
class LowFiSmokeEnv(gym.Env[np.ndarray, int]):
    """Gymnasium wrapper around ``StellaratorEnvironment`` for the PPO smoke run.

    Observations are flat ``float32`` vectors with ``ENCODED_OBSERVATION_DIM``
    entries. Actions are indices into ``DIAGNOSTIC_RUN_ACTION_SPECS``.

    Seeding: ``reset(seed=s)`` resets the wrapped environment with
    ``s % len(RESET_SEEDS)`` and restarts the training cycle, while ``reset()``
    without a seed walks through ``TRAIN_RESET_SEED_INDICES`` in order.
    """

    metadata = {"render_modes": []}

    def __init__(self) -> None:
        """Create the wrapped environment and declare the Gymnasium spaces."""
        super().__init__()
        self._env = StellaratorEnvironment()
        # Seed handed to the wrapped environment on the most recent reset.
        self._seed = 0
        # Count of unseeded resets since the last explicitly seeded one.
        self._episode_index = 0
        unbounded = np.inf
        self.observation_space = spaces.Box(
            low=-unbounded,
            high=unbounded,
            shape=(ENCODED_OBSERVATION_DIM,),
            dtype=np.float32,
        )
        self.action_space = spaces.Discrete(LOW_FI_ACTION_COUNT)

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, object] | None = None,
    ) -> tuple[np.ndarray, dict[str, object]]:
        """Start a new episode and return ``(observation, info)``.

        ``options`` is accepted but not used.
        """
        super().reset(seed=seed)
        self._seed = self._next_seed(seed)
        initial = self._env.reset(seed=self._seed)
        encoded = self._encode_observation(initial)
        return encoded, self._info(initial)

    def _next_seed(self, seed: int | None) -> int:
        """Choose the seed passed to the wrapped environment's reset.

        NOTE(review): an explicit seed bypasses ``TRAIN_RESET_SEED_INDICES``;
        if the training framework seeds the first reset, that episode will not
        use a training index -- confirm this is intended.
        """
        if seed is None:
            cursor = self._episode_index % len(TRAIN_RESET_SEED_INDICES)
            chosen = TRAIN_RESET_SEED_INDICES[cursor]
            self._episode_index += 1
            return chosen
        # An explicit seed restarts the training cycle from its first entry.
        self._episode_index = 0
        return seed % len(RESET_SEEDS)

    def step(
        self,
        action: int,
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, object]]:
        """Apply one action and return the Gymnasium five-tuple.

        ``truncated`` is always ``False``; every episode end is reported
        through ``terminated``.
        """
        outcome = self._env.step(self._decode_action(action))
        terminated = self._is_terminal_observation(outcome)
        encoded = self._encode_observation(outcome)
        # A missing reward is reported as 0.0.
        reward = 0.0 if outcome.reward is None else float(outcome.reward)
        return encoded, reward, terminated, False, self._info(outcome)

    def _decode_action(self, action: int) -> StellaratorAction:
        """Map a discrete action index to the action sent to the wrapped env."""
        return diagnostic_action(action)

    def action_label(self, action: int) -> str:
        """Return the human-readable label for a discrete action index."""
        return diagnostic_action_label(action)

    def _encode_observation(self, obs: StellaratorObservation) -> np.ndarray:
        """Flatten an observation and the current design parameters into float32."""
        design = self._env.state.current_params
        budget_fraction = obs.budget_remaining / BUDGET
        step_fraction = obs.step_number / BUDGET
        features = (
            # Metrics reported by the observation.
            obs.max_elongation,
            obs.aspect_ratio,
            obs.average_triangularity,
            obs.edge_iota_over_nfp,
            obs.p1_score,
            obs.p1_feasibility,
            obs.vacuum_well,
            # Current design parameters held by the wrapped environment.
            design.aspect_ratio,
            design.elongation,
            design.rotational_transform,
            design.triangularity_scale,
            # Episode progress, both divided by BUDGET.
            budget_fraction,
            step_fraction,
            # Best-so-far values carried by the observation.
            obs.best_low_fidelity_score,
            obs.best_low_fidelity_feasibility,
            # Status flags encoded as 0.0 / 1.0.
            float(obs.constraints_satisfied),
            float(obs.evaluation_failed),
        )
        return np.array(features, dtype=np.float32)

    def _info(self, obs: StellaratorObservation) -> dict[str, object]:
        """Build the info dict returned by ``reset`` and ``step``."""
        # These observation attributes are copied under their own names.
        copied_fields = (
            "diagnostics_text",
            "budget_remaining",
            "constraints_satisfied",
            "evaluation_failed",
            "p1_score",
            "p1_feasibility",
            "max_elongation",
            "average_triangularity",
            "edge_iota_over_nfp",
        )
        info: dict[str, object] = {name: getattr(obs, name) for name in copied_fields}
        info["termination_reason"] = self._termination_reason(obs)
        info["current_seed"] = self._seed
        return info

    def _termination_reason(self, obs: StellaratorObservation) -> str:
        """Label the episode state; a failed evaluation takes precedence."""
        if obs.evaluation_failed:
            reason = "evaluation_failed"
        elif obs.constraints_satisfied:
            reason = "constraints_satisfied"
        elif obs.done:
            reason = "budget_exhausted"
        else:
            reason = "in_progress"
        return reason

    def _is_terminal_observation(self, obs: StellaratorObservation) -> bool:
        """Return whether the observation is done, satisfied, or failed."""
        if obs.done or obs.constraints_satisfied:
            return True
        return bool(obs.evaluation_failed)
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the PPO smoke run.

    Args:
        argv: Argument strings to parse. When ``None`` (the default), argparse
            reads ``sys.argv[1:]``, which is what the script entry point relies
            on. Passing an explicit list lets callers drive the parser without
            touching process state.

    Returns:
        Namespace with ``total_timesteps``, ``eval_episodes``, ``seed`` and
        ``output_dir`` attributes.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run a tiny low-fidelity PPO smoke pass against the repaired Fusion Design Lab "
            "environment and save a small trajectory artifact."
        )
    )
    parser.add_argument(
        "--total-timesteps",
        type=int,
        default=DEFAULT_TOTAL_TIMESTEPS,
        help=f"Total PPO timesteps for the smoke run (default: {DEFAULT_TOTAL_TIMESTEPS}).",
    )
    parser.add_argument(
        "--eval-episodes",
        type=int,
        default=DEFAULT_EVAL_EPISODES,
        help=f"Number of deterministic evaluation episodes to record (default: {DEFAULT_EVAL_EPISODES}).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Base seed for training and evaluation.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=DEFAULT_OUTPUT_DIR,
        help="Directory where the JSON artifact should be written.",
    )
    # Forwarding None keeps the original behaviour of reading sys.argv.
    return parser.parse_args(argv)
|
|
|
|
def build_model(env: LowFiSmokeEnv, seed: int) -> PPO:
    """Construct the CPU-only PPO model used for the smoke run.

    Args:
        env: Environment the model trains on.
        seed: Seed forwarded to PPO.

    Returns:
        A PPO instance with an ``MlpPolicy`` and the fixed hyperparameters below.
    """
    # Fixed hyperparameters for the smoke run.
    hyperparameters = {
        "n_steps": 16,
        "batch_size": 16,
        "n_epochs": 8,
        "gamma": 0.995,
        "learning_rate": 3e-4,
        "ent_coef": 0.01,
    }
    return PPO(
        policy="MlpPolicy",
        env=env,
        seed=seed,
        verbose=0,
        device="cpu",
        **hyperparameters,
    )
|
|
|
|
def _trace_step_from_info(
    *,
    step: int,
    action_index: int,
    action_label: str,
    reward: float,
    info: dict[str, object],
) -> TraceStep:
    """Build the trace record for one step from the env's info dict."""
    return TraceStep(
        step=step,
        action_index=action_index,
        action_label=action_label,
        reward=reward,
        score=float(info["p1_score"]),
        feasibility=float(info["p1_feasibility"]),
        constraints_satisfied=bool(info["constraints_satisfied"]),
        evaluation_failed=bool(info["evaluation_failed"]),
        budget_remaining=int(info["budget_remaining"]),
        termination_reason=str(info["termination_reason"]),
        max_elongation=float(info["max_elongation"]),
        average_triangularity=float(info["average_triangularity"]),
        edge_iota_over_nfp=float(info["edge_iota_over_nfp"]),
    )


def evaluate_policy(
    model: PPO, *, eval_episodes: int, base_seed: int
) -> tuple[list[EpisodeTrace], list[int]]:
    """Roll out the policy deterministically and record every step.

    Episode ``i`` is reset with seed ``base_seed + i`` on a fresh
    ``LowFiSmokeEnv`` and runs until the env reports termination or truncation.

    Args:
        model: Policy used to pick actions (``deterministic=True``).
        eval_episodes: Number of episodes to run.
        base_seed: Seed of the first episode.

    Returns:
        The episode traces, and for each episode ``seed % len(RESET_SEEDS)``.
    """
    episode_traces: list[EpisodeTrace] = []
    reset_seed_indices: list[int] = []
    env = LowFiSmokeEnv()

    for episode in range(eval_episodes):
        seed = base_seed + episode
        reset_seed_indices.append(seed % len(RESET_SEEDS))
        obs, info = env.reset(seed=seed)

        # Falls back to the reset info if the loop never replaces it.
        last_info = dict(info)
        recorded: list[TraceStep] = []
        reward_sum = 0.0

        while True:
            action, _ = model.predict(obs, deterministic=True)
            action_index = int(action)
            obs, reward, terminated, truncated, info = env.step(action_index)
            reward_sum += reward
            last_info = info
            recorded.append(
                _trace_step_from_info(
                    step=len(recorded) + 1,
                    action_index=action_index,
                    action_label=env.action_label(action_index),
                    reward=reward,
                    info=info,
                )
            )
            if terminated or truncated:
                break

        episode_traces.append(
            EpisodeTrace(
                episode=episode,
                seed=seed,
                total_reward=round(reward_sum, 4),
                final_score=float(last_info["p1_score"]),
                final_feasibility=float(last_info["p1_feasibility"]),
                constraints_satisfied=bool(last_info["constraints_satisfied"]),
                evaluation_failed=bool(last_info["evaluation_failed"]),
                termination_reason=str(last_info["termination_reason"]),
                steps=recorded,
            )
        )

    return episode_traces, reset_seed_indices
|
|
|
|
def artifact_payload(
    *,
    total_timesteps: int,
    eval_episodes: int,
    seed: int,
    eval_reset_seed_indices: list[int],
    traces: list[EpisodeTrace],
) -> dict[str, object]:
    """Assemble the dictionary describing one smoke run.

    Args:
        total_timesteps: Timestep count recorded for the run.
        eval_episodes: Evaluation episode count recorded for the run.
        seed: Base seed of the run.
        eval_reset_seed_indices: Reset-seed index recorded per evaluation episode.
        traces: Evaluation episode traces to embed and summarise.

    Returns:
        Run metadata, a ``summary`` (mean reward and constraint-satisfaction
        rate, both rounded to 4 places) and the traces as plain dicts.
    """
    # Dividing by at least 1 keeps both averages at 0 when there are no traces.
    denominator = max(len(traces), 1)
    reward_total = sum(trace.total_reward for trace in traces)
    satisfied_count = sum(1 for trace in traces if trace.constraints_satisfied)
    summary = {
        "mean_eval_reward": round(reward_total / denominator, 4),
        "constraint_satisfaction_rate": round(satisfied_count / denominator, 4),
    }

    payload: dict[str, object] = {
        "created_at_utc": datetime.now(UTC).isoformat(),
        "mode": "low_fidelity_ppo_smoke",
        "total_timesteps": total_timesteps,
        "eval_episodes": eval_episodes,
        "seed": seed,
        "train_reset_seed_indices": list(TRAIN_RESET_SEED_INDICES),
        "eval_reset_seed_indices": eval_reset_seed_indices,
        "action_space_size": LOW_FI_ACTION_COUNT,
        "diagnostic_run_actions": list(
            map(diagnostic_action_label, range(LOW_FI_ACTION_COUNT))
        ),
        "notes": (
            "Diagnostics-only low-fidelity PPO smoke; submit is excluded and the action "
            "space is narrowed to a two-step repair arc. Evaluation runs across "
            "frozen seeds and records full low-fi traces."
        ),
        "summary": summary,
        "episodes": [asdict(trace) for trace in traces],
    }
    return payload
|
|
|
|
| def write_artifact(output_dir: Path, payload: dict[str, object]) -> Path: |
| output_dir.mkdir(parents=True, exist_ok=True) |
| timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") |
| output_path = output_dir / f"ppo_smoke_{timestamp}.json" |
| output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n") |
| return output_path |
|
|
|
|
def main() -> None:
    """Train the smoke model, evaluate it, write the artifact and print a summary.

    Prints the artifact path followed by the constraint-satisfaction rate and
    the mean evaluation reward.
    """
    cli = parse_args()

    # Train on a dedicated environment instance.
    model = build_model(LowFiSmokeEnv(), seed=cli.seed)
    model.learn(total_timesteps=cli.total_timesteps, progress_bar=False)

    traces, eval_reset_seed_indices = evaluate_policy(
        model,
        eval_episodes=cli.eval_episodes,
        base_seed=cli.seed,
    )
    payload = artifact_payload(
        total_timesteps=cli.total_timesteps,
        eval_episodes=cli.eval_episodes,
        seed=cli.seed,
        eval_reset_seed_indices=eval_reset_seed_indices,
        traces=traces,
    )

    print(write_artifact(cli.output_dir, payload))
    summary = payload["summary"]
    print(f"constraint_satisfaction_rate={summary['constraint_satisfaction_rate']}")
    print(f"mean_eval_reward={summary['mean_eval_reward']}")
|
|
|
|
# Run the smoke pass only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|