# fusion-design-lab / training / ppo_smoke.py
# CreativeEngineer — fix: restore ppo smoke early termination (9827b11)
from __future__ import annotations
import argparse
import json
from dataclasses import asdict, dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Final
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import PPO
from fusion_lab.models import StellaratorAction, StellaratorObservation
from server.contract import RESET_SEEDS
from server.environment import BUDGET, StellaratorEnvironment
# Directory that receives the JSON artifact unless --output-dir overrides it.
DEFAULT_OUTPUT_DIR: Final[Path] = Path("training") / "artifacts" / "ppo_smoke"

# CLI defaults for the training length and the number of recorded eval episodes.
DEFAULT_TOTAL_TIMESTEPS: Final[int] = 32
DEFAULT_EVAL_EPISODES: Final[int] = 3

# Length of the feature vector produced by LowFiSmokeEnv._encode_observation;
# the two must be changed together.
ENCODED_OBSERVATION_DIM: Final[int] = 17

# One (parameter, direction, magnitude) triple per discrete action index.
DIAGNOSTIC_RUN_ACTION_SPECS: Final[tuple[tuple[str, str, str], ...]] = (
    ("rotational_transform", "increase", "medium"),
    ("triangularity_scale", "increase", "medium"),
)

# Reset seeds cycled through by LowFiSmokeEnv when reset() gets no explicit seed.
TRAIN_RESET_SEED_INDICES: Final[tuple[int, ...]] = (2,)

# Size of the discrete action space: one action per diagnostic spec.
LOW_FI_ACTION_COUNT: Final[int] = len(DIAGNOSTIC_RUN_ACTION_SPECS)
@dataclass(frozen=True)
class TraceStep:
    """Immutable record of a single environment step taken during evaluation.

    The field order is part of the contract: it fixes both positional
    construction and the key order produced by ``dataclasses.asdict``.
    """

    # Position of the step within its episode and the discrete action taken.
    step: int
    action_index: int
    action_label: str

    # Outcome of the step as reported by the environment.
    reward: float
    score: float
    feasibility: float
    constraints_satisfied: bool
    evaluation_failed: bool
    budget_remaining: int
    termination_reason: str

    # Geometry diagnostics captured alongside the outcome.
    max_elongation: float
    average_triangularity: float
    edge_iota_over_nfp: float
@dataclass(frozen=True)
class EpisodeTrace:
    """Immutable summary of one evaluation episode plus its per-step records.

    The field order is part of the contract: it fixes both positional
    construction and the key order produced by ``dataclasses.asdict``.
    """

    # Which episode this is and the seed that was passed to reset().
    episode: int
    seed: int

    # Aggregate and end-of-episode outcome.
    total_reward: float
    final_score: float
    final_feasibility: float
    constraints_satisfied: bool
    evaluation_failed: bool
    termination_reason: str

    # Step records in the order the steps were taken.
    steps: list[TraceStep]
def diagnostic_action(action_index: int) -> StellaratorAction:
    """Build the ``run`` action described by one ``DIAGNOSTIC_RUN_ACTION_SPECS`` entry.

    Args:
        action_index: Position of the spec to use. Ordinary tuple indexing
            applies, so an out-of-range index raises ``IndexError``.

    Returns:
        A ``StellaratorAction`` with ``intent="run"`` and the spec's
        parameter, direction and magnitude.
    """
    spec = DIAGNOSTIC_RUN_ACTION_SPECS[action_index]
    # strict=True keeps the "exactly three values" check of a tuple unpack.
    spec_fields = dict(zip(("parameter", "direction", "magnitude"), spec, strict=True))
    return StellaratorAction(intent="run", **spec_fields)
def diagnostic_action_label(action_index: int) -> str:
    """Return the ``"<parameter> <direction> <magnitude>"`` label for an action index.

    The label is read back from the constructed action rather than from the
    raw spec, so it reflects whatever the action model actually stores.
    """
    built = diagnostic_action(action_index)
    return f"{built.parameter} {built.direction} {built.magnitude}"
class LowFiSmokeEnv(gym.Env[np.ndarray, int]):
    """Gymnasium adapter that exposes ``StellaratorEnvironment`` to PPO.

    Observations are float32 vectors of length ``ENCODED_OBSERVATION_DIM``.
    Actions are indices into ``DIAGNOSTIC_RUN_ACTION_SPECS``. Episodes only ever
    *terminate*; the ``truncated`` flag returned by ``step`` is always ``False``.
    """

    metadata = {"render_modes": []}

    def __init__(self) -> None:
        """Create the wrapped environment and declare the Gymnasium spaces."""
        super().__init__()
        self._env = StellaratorEnvironment()
        self._seed = 0
        self._episode_index = 0
        # Keep this aligned with _encode_observation feature count.
        observation_shape = (ENCODED_OBSERVATION_DIM,)
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=observation_shape,
            dtype=np.float32,
        )
        self.action_space = spaces.Discrete(LOW_FI_ACTION_COUNT)

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, object] | None = None,
    ) -> tuple[np.ndarray, dict[str, object]]:
        """Start a new episode and return ``(encoded_observation, info)``.

        Args:
            seed: Explicit seed, or ``None`` to take the next training seed.
            options: Accepted for Gymnasium API compatibility; not used.
        """
        super().reset(seed=seed)
        self._seed = self._next_seed(seed)
        initial = self._env.reset(seed=self._seed)
        return self._encode_observation(initial), self._info(initial)

    def _next_seed(self, seed: int | None) -> int:
        """Pick the seed handed to the wrapped environment's ``reset``.

        Without an explicit seed the values of ``TRAIN_RESET_SEED_INDICES`` are
        used round-robin. With one, the rotation restarts and the seed is
        folded into ``range(len(RESET_SEEDS))``.
        """
        # NOTE(review): any caller that forwards a seed to reset() bypasses the
        # training rotation for that episode — confirm this is intended.
        if seed is None:
            rotation = TRAIN_RESET_SEED_INDICES
            chosen = rotation[self._episode_index % len(rotation)]
            self._episode_index += 1
            return chosen
        self._episode_index = 0
        return seed % len(RESET_SEEDS)

    def step(
        self,
        action: int,
    ) -> tuple[np.ndarray, float, bool, bool, dict[str, object]]:
        """Apply one discrete action and return the Gymnasium 5-tuple."""
        result = self._env.step(self._decode_action(action))
        terminated = self._is_terminal_observation(result)
        encoded = self._encode_observation(result)
        # A missing reward is reported to the learner as 0.0.
        reward = 0.0 if result.reward is None else float(result.reward)
        return encoded, reward, terminated, False, self._info(result)

    def _decode_action(self, action: int) -> StellaratorAction:
        """Translate a discrete action index into a ``StellaratorAction``."""
        return diagnostic_action(action)

    def action_label(self, action: int) -> str:
        """Return the human-readable label of a discrete action index."""
        return diagnostic_action_label(action)

    def _encode_observation(self, obs: StellaratorObservation) -> np.ndarray:
        """Flatten an observation and the current design parameters into float32.

        The feature order is fixed, and the number of features must equal
        ``ENCODED_OBSERVATION_DIM``.
        """
        params = self._env.state.current_params
        budget_fraction = obs.budget_remaining / BUDGET
        step_fraction = obs.step_number / BUDGET
        measured = [
            obs.max_elongation,
            obs.aspect_ratio,
            obs.average_triangularity,
            obs.edge_iota_over_nfp,
            obs.p1_score,
            obs.p1_feasibility,
            obs.vacuum_well,
        ]
        design = [
            params.aspect_ratio,
            params.elongation,
            params.rotational_transform,
            params.triangularity_scale,
        ]
        progress = [
            budget_fraction,
            step_fraction,
            obs.best_low_fidelity_score,
            obs.best_low_fidelity_feasibility,
        ]
        flags = [
            float(obs.constraints_satisfied),
            float(obs.evaluation_failed),
        ]
        return np.array(measured + design + progress + flags, dtype=np.float32)

    def _info(self, obs: StellaratorObservation) -> dict[str, object]:
        """Build the info dict returned by both ``reset`` and ``step``."""
        copied_fields = (
            "diagnostics_text",
            "budget_remaining",
            "constraints_satisfied",
            "evaluation_failed",
            "p1_score",
            "p1_feasibility",
            "max_elongation",
            "average_triangularity",
            "edge_iota_over_nfp",
        )
        info: dict[str, object] = {name: getattr(obs, name) for name in copied_fields}
        info["termination_reason"] = self._termination_reason(obs)
        info["current_seed"] = self._seed
        return info

    def _termination_reason(self, obs: StellaratorObservation) -> str:
        """Name why the episode ended, or ``"in_progress"`` if it has not.

        The first matching condition wins, in this order: evaluation failure,
        satisfied constraints, then ``done`` (reported as budget exhaustion).
        """
        ordered_checks = (
            (obs.evaluation_failed, "evaluation_failed"),
            (obs.constraints_satisfied, "constraints_satisfied"),
            (obs.done, "budget_exhausted"),
        )
        return next((label for flag, label in ordered_checks if flag), "in_progress")

    def _is_terminal_observation(self, obs: StellaratorObservation) -> bool:
        """Return ``True`` when any of the three end-of-episode flags is set."""
        return any((obs.done, obs.constraints_satisfied, obs.evaluation_failed))
def parse_args() -> argparse.Namespace:
    """Parse the smoke-run command line taken from ``sys.argv``.

    Returns:
        A namespace with ``total_timesteps``, ``eval_episodes``, ``seed`` and
        ``output_dir`` attributes.
    """
    description = (
        "Run a tiny low-fidelity PPO smoke pass against the repaired Fusion Design Lab "
        "environment and save a small trajectory artifact."
    )
    # (flag, type, default, help) — registered in this order.
    option_table = (
        (
            "--total-timesteps",
            int,
            DEFAULT_TOTAL_TIMESTEPS,
            f"Total PPO timesteps for the smoke run (default: {DEFAULT_TOTAL_TIMESTEPS}).",
        ),
        (
            "--eval-episodes",
            int,
            DEFAULT_EVAL_EPISODES,
            f"Number of deterministic evaluation episodes to record (default: {DEFAULT_EVAL_EPISODES}).",
        ),
        (
            "--seed",
            int,
            0,
            "Base seed for training and evaluation.",
        ),
        (
            "--output-dir",
            Path,
            DEFAULT_OUTPUT_DIR,
            "Directory where the JSON artifact should be written.",
        ),
    )
    parser = argparse.ArgumentParser(description=description)
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return parser.parse_args()
def build_model(env: LowFiSmokeEnv, seed: int) -> PPO:
    """Create the CPU-only ``MlpPolicy`` PPO model used for the smoke run.

    Args:
        env: Environment the model trains on.
        seed: Seed forwarded to PPO.
    """
    # Rollout length equals the minibatch size: each update is one full batch.
    hyperparameters = {
        "n_steps": 16,
        "batch_size": 16,
        "n_epochs": 8,
        "gamma": 0.995,
        "learning_rate": 3e-4,
        "ent_coef": 0.01,
    }
    return PPO(
        policy="MlpPolicy",
        env=env,
        seed=seed,
        verbose=0,
        device="cpu",
        **hyperparameters,
    )
def _trace_step_from_info(
    *,
    step: int,
    action_index: int,
    action_label: str,
    reward: float,
    info: dict[str, object],
) -> TraceStep:
    """Build a ``TraceStep`` from one transition's reward and info dict."""
    return TraceStep(
        step=step,
        action_index=action_index,
        action_label=action_label,
        reward=reward,
        score=float(info["p1_score"]),
        feasibility=float(info["p1_feasibility"]),
        constraints_satisfied=bool(info["constraints_satisfied"]),
        evaluation_failed=bool(info["evaluation_failed"]),
        budget_remaining=int(info["budget_remaining"]),
        termination_reason=str(info["termination_reason"]),
        max_elongation=float(info["max_elongation"]),
        average_triangularity=float(info["average_triangularity"]),
        edge_iota_over_nfp=float(info["edge_iota_over_nfp"]),
    )


def evaluate_policy(
    model: PPO, *, eval_episodes: int, base_seed: int
) -> tuple[list[EpisodeTrace], list[int]]:
    """Roll out deterministic evaluation episodes and record a trace for each.

    Episode ``i`` resets a fresh ``LowFiSmokeEnv`` with seed ``base_seed + i``
    and follows ``model.predict(..., deterministic=True)`` until the
    environment reports termination or truncation.

    Args:
        model: Trained PPO model to evaluate.
        eval_episodes: Number of episodes to run; values <= 0 yield empty results.
        base_seed: Seed of the first episode.

    Returns:
        ``(traces, eval_reset_seed_indices)``: one ``EpisodeTrace`` per episode,
        and for each episode the reset seed the environment reported using.
    """
    traces: list[EpisodeTrace] = []
    eval_reset_seed_indices: list[int] = []
    env = LowFiSmokeEnv()
    try:
        for episode in range(eval_episodes):
            seed = base_seed + episode
            obs, info = env.reset(seed=seed)
            # Take the seed from the env's own info instead of re-deriving it,
            # so the artifact cannot drift from what the env actually used.
            eval_reset_seed_indices.append(int(info["current_seed"]))
            final_info = dict(info)
            total_reward = 0.0
            steps: list[TraceStep] = []
            done = False
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                action_index = int(action)
                obs, reward, terminated, truncated, info = env.step(action_index)
                done = terminated or truncated
                total_reward += reward
                final_info = info
                steps.append(
                    _trace_step_from_info(
                        step=len(steps) + 1,
                        action_index=action_index,
                        action_label=env.action_label(action_index),
                        reward=reward,
                        info=info,
                    )
                )
            traces.append(
                EpisodeTrace(
                    episode=episode,
                    seed=seed,
                    total_reward=round(total_reward, 4),
                    final_score=float(final_info["p1_score"]),
                    final_feasibility=float(final_info["p1_feasibility"]),
                    constraints_satisfied=bool(final_info["constraints_satisfied"]),
                    evaluation_failed=bool(final_info["evaluation_failed"]),
                    termination_reason=str(final_info["termination_reason"]),
                    steps=steps,
                )
            )
    finally:
        # This function owns the env it created, so it always runs the
        # Gymnasium cleanup hook, even when a rollout raises.
        env.close()
    return traces, eval_reset_seed_indices
def artifact_payload(
    *,
    total_timesteps: int,
    eval_episodes: int,
    seed: int,
    eval_reset_seed_indices: list[int],
    traces: list[EpisodeTrace],
) -> dict[str, object]:
    """Assemble the JSON-serialisable artifact describing one smoke run.

    Args:
        total_timesteps: Training timesteps that were requested.
        eval_episodes: Evaluation episodes that were requested.
        seed: Base seed of the run.
        eval_reset_seed_indices: Reset seed recorded for each evaluation episode.
        traces: Evaluation traces; an empty list yields 0.0 for both summary values.

    Returns:
        The run configuration, summary statistics and every trace as plain
        dicts, stamped with the current UTC time.
    """
    # Dividing by at least 1 keeps the summary defined when there are no traces.
    denominator = max(len(traces), 1)
    reward_sum = sum(trace.total_reward for trace in traces)
    satisfied_count = sum(1 for trace in traces if trace.constraints_satisfied)
    summary = {
        "mean_eval_reward": round(reward_sum / denominator, 4),
        "constraint_satisfaction_rate": round(satisfied_count / denominator, 4),
    }
    action_labels = [diagnostic_action_label(index) for index in range(LOW_FI_ACTION_COUNT)]
    notes = (
        "Diagnostics-only low-fidelity PPO smoke; submit is excluded and the action "
        "space is narrowed to a two-step repair arc. Evaluation runs across "
        "frozen seeds and records full low-fi traces."
    )
    payload: dict[str, object] = {
        "created_at_utc": datetime.now(UTC).isoformat(),
        "mode": "low_fidelity_ppo_smoke",
        "total_timesteps": total_timesteps,
        "eval_episodes": eval_episodes,
        "seed": seed,
        "train_reset_seed_indices": list(TRAIN_RESET_SEED_INDICES),
        "eval_reset_seed_indices": eval_reset_seed_indices,
        "action_space_size": LOW_FI_ACTION_COUNT,
        "diagnostic_run_actions": action_labels,
        "notes": notes,
        "summary": summary,
        "episodes": [asdict(trace) for trace in traces],
    }
    return payload
def write_artifact(output_dir: Path, payload: dict[str, object]) -> Path:
output_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
output_path = output_dir / f"ppo_smoke_{timestamp}.json"
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n")
return output_path
def main() -> None:
    """Train the smoke PPO model, evaluate it, write the artifact, print a summary.

    Prints the artifact path followed by the constraint-satisfaction rate and
    the mean evaluation reward, one per line.
    """
    cli = parse_args()

    # Train on a dedicated env; evaluate_policy builds its own for the rollouts.
    train_env = LowFiSmokeEnv()
    agent = build_model(train_env, seed=cli.seed)
    agent.learn(total_timesteps=cli.total_timesteps, progress_bar=False)

    episode_traces, reset_seed_indices = evaluate_policy(
        agent,
        eval_episodes=cli.eval_episodes,
        base_seed=cli.seed,
    )
    payload = artifact_payload(
        total_timesteps=cli.total_timesteps,
        eval_episodes=cli.eval_episodes,
        seed=cli.seed,
        eval_reset_seed_indices=reset_seed_indices,
        traces=episode_traces,
    )
    artifact_path = write_artifact(cli.output_dir, payload)

    summary = payload["summary"]
    print(artifact_path)
    print(f"constraint_satisfaction_rate={summary['constraint_satisfaction_rate']}")
    print(f"mean_eval_reward={summary['mean_eval_reward']}")


if __name__ == "__main__":
    main()