| """Validation utilities for high-fidelity fixture pairing and submit-side traces.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from dataclasses import asdict, dataclass |
| from datetime import UTC, datetime |
| from pathlib import Path |
| from pprint import pformat |
| from time import perf_counter |
| from typing import Any |
|
|
| from fusion_lab.models import LowDimBoundaryParams, StellaratorAction |
| from server.contract import N_FIELD_PERIODS |
| from server.environment import StellaratorEnvironment |
| from server.physics import EvaluationMetrics, build_boundary_from_params, evaluate_boundary |
|
|
|
|
# Maximum absolute drift allowed between a stored low-fidelity snapshot value
# and a freshly recomputed one before the field is reported as a mismatch.
LOW_FIDELITY_TOLERANCE = 1.0e-6
|
|
|
|
| def _float(value: Any) -> float | None: |
| if isinstance(value, bool): |
| return None |
| try: |
| return float(value) |
| except (TypeError, ValueError): |
| return None |
|
|
|
|
@dataclass(frozen=True)
class FixturePairResult:
    """Summary of one fixture evaluated at both low and high fidelity."""

    name: str  # fixture name from the JSON payload (falls back to the file stem)
    file: str  # path of the fixture file that produced this result
    status: str  # "pass" or "fail"
    low_fidelity: dict[str, Any]  # metrics payload from the low-fidelity evaluation
    high_fidelity: dict[str, Any]  # metrics payload from the high-fidelity evaluation
    comparison: dict[str, Any]  # cross-fidelity agreement details (deltas, snapshot diff)
|
|
|
|
@dataclass(frozen=True)
class TraceStep:
    """One environment step recorded in the submit-side manual trace."""

    step: int  # 1-based step index (the reset snapshot is emitted as step 0)
    intent: str  # action intent, e.g. "run" or "submit"
    action: str  # human-readable action description
    reward: float  # reward reported by the environment for this step
    score: float  # p1_score observed after the step
    feasibility: float  # p1_feasibility observed after the step
    constraints_satisfied: bool
    feasibility_delta: float | None  # change vs. previous step (None only at reset)
    score_delta: float | None  # change vs. previous step (None only at reset)
    max_elongation: float
    p1_feasibility: float  # duplicated alongside `feasibility` in the serialized payload
    budget_remaining: int
    evaluation_fidelity: str  # fidelity label from the observation — presumably "low"/"high"; confirm against env
    done: bool  # True once the episode has terminated
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI options for fixture checks and the submit trace."""
    # Default manual action sequence: three run actions followed by a submit.
    default_sequence = (
        "run:rotational_transform:increase:medium,"
        "run:triangularity_scale:increase:medium,"
        "run:elongation:decrease:small,"
        "submit"
    )
    cli = argparse.ArgumentParser(
        description=(
            "Run paired high-fidelity fixture checks and a submit-side manual trace "
            "for the repaired P1 contract."
        )
    )
    cli.add_argument(
        "--fixture-dir",
        type=Path,
        default=Path("server/data/p1"),
        help="Directory containing tracked P1 fixture JSON files.",
    )
    cli.add_argument(
        "--fixture-output",
        type=Path,
        default=Path("baselines/fixture_high_fidelity_pairs.json"),
        help="Output path for paired fixture summary JSON.",
    )
    cli.add_argument(
        "--trace-output",
        type=Path,
        default=Path("baselines/submit_side_trace.json"),
        help="Output path for one submit-side manual trace JSON.",
    )
    cli.add_argument(
        "--no-write-fixture-updates",
        action="store_true",
        help="Do not write paired high-fidelity results back into fixture files.",
    )
    cli.add_argument(
        "--skip-submit-trace",
        action="store_true",
        help="Only run paired fixture checks.",
    )
    cli.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for the submit-side manual trace reset state.",
    )
    cli.add_argument(
        "--submit-action-sequence",
        type=str,
        default=default_sequence,
        help=(
            "Comma-separated submit trace sequence. "
            "Run actions are `run:parameter:direction:magnitude`; include `submit` as the last token."
        ),
    )
    return cli.parse_args()
|
|
|
|
| def _fixture_files(fixture_dir: Path) -> list[Path]: |
| return sorted(path for path in fixture_dir.glob("*.json") if path.is_file()) |
|
|
|
|
| def _load_fixture(path: Path) -> dict[str, Any]: |
| with path.open("r") as file: |
| return json.load(file) |
|
|
|
|
| def _metrics_payload(metrics: EvaluationMetrics) -> dict[str, Any]: |
| return { |
| "evaluation_failed": metrics.evaluation_failed, |
| "constraints_satisfied": metrics.constraints_satisfied, |
| "p1_score": metrics.p1_score, |
| "p1_feasibility": metrics.p1_feasibility, |
| "max_elongation": metrics.max_elongation, |
| "aspect_ratio": metrics.aspect_ratio, |
| "average_triangularity": metrics.average_triangularity, |
| "edge_iota_over_nfp": metrics.edge_iota_over_nfp, |
| "vacuum_well": metrics.vacuum_well, |
| "evaluation_fidelity": metrics.evaluation_fidelity, |
| "failure_reason": metrics.failure_reason, |
| } |
|
|
|
|
def _parse_submit_sequence(raw: str) -> list[StellaratorAction]:
    """Parse a comma-separated action string into environment actions.

    Tokens are either the literal ``submit`` or
    ``run:parameter:direction:magnitude``. Blank tokens are skipped. The
    sequence must be non-empty and its final action must be a submit.

    Raises:
        ValueError: On a malformed token, an empty sequence, or a sequence
            that does not end with ``submit``.
    """
    parsed: list[StellaratorAction] = []
    for raw_token in raw.split(","):
        token = raw_token.strip()
        if not token:
            continue
        if token == "submit":
            parsed.append(StellaratorAction(intent="submit"))
        else:
            fields = token.split(":")
            if len(fields) != 4 or fields[0] != "run":
                raise ValueError(
                    "Expected token format `run:parameter:direction:magnitude` or `submit`."
                )
            parsed.append(
                StellaratorAction(
                    intent="run",
                    parameter=fields[1],
                    direction=fields[2],
                    magnitude=fields[3],
                )
            )

    if not parsed:
        raise ValueError("submit-action-sequence must include at least one action.")
    if parsed[-1].intent != "submit":
        raise ValueError("submit-action-sequence must end with submit.")
    return parsed
|
|
|
|
def _compare_low_snapshot(
    stored: dict[str, Any],
    current: dict[str, Any],
) -> tuple[bool, dict[str, Any]]:
    """Diff a stored low-fidelity snapshot against freshly computed metrics.

    Numeric fields may drift by at most ``LOW_FIDELITY_TOLERANCE``; the
    bookkeeping fields must match exactly. Returns ``(ok, report)`` where
    *ok* is True only when nothing is missing and nothing mismatches, and
    *report* details missing fields, drifting fields, and all mismatches.
    """
    numeric_keys = (
        "p1_feasibility",
        "p1_score",
        "max_elongation",
        "aspect_ratio",
        "average_triangularity",
        "edge_iota_over_nfp",
        "vacuum_well",
    )
    exact_keys = (
        "constraints_satisfied",
        "evaluation_fidelity",
        "evaluation_failed",
        "failure_reason",
    )
    missing_fields: list[str] = []
    drift_fields: dict[str, dict[str, float]] = {}
    mismatches: list[dict[str, Any]] = []
    max_abs_drift = 0.0

    for key in numeric_keys:
        if key not in stored:
            missing_fields.append(key)
            continue
        expected = _float(stored.get(key))
        actual = _float(current.get(key))
        if expected is None or actual is None:
            # One side is non-numeric (or absent from `current`); report the
            # raw values verbatim instead of a drift figure.
            mismatches.append(
                {
                    "field": key,
                    "expected": stored.get(key),
                    "actual": current.get(key),
                    "reason": "non-numeric",
                }
            )
            continue
        drift = abs(expected - actual)
        if drift > max_abs_drift:
            max_abs_drift = drift
        if drift > LOW_FIDELITY_TOLERANCE:
            detail = {"expected": expected, "actual": actual, "abs_drift": drift}
            drift_fields[key] = detail
            mismatches.append({"field": key, **detail})

    for key in exact_keys:
        if key not in stored:
            missing_fields.append(key)
            continue
        if stored.get(key) != current.get(key):
            mismatches.append(
                {
                    "field": key,
                    "expected": stored.get(key),
                    "actual": current.get(key),
                    "reason": "exact-mismatch",
                }
            )

    ok = not missing_fields and not mismatches
    report = {
        "missing_fields": missing_fields,
        "drift_fields": drift_fields,
        "mismatches": mismatches,
        "max_abs_drift": max_abs_drift,
    }
    return ok, report
|
|
|
|
def _pair_fixture(path: Path) -> FixturePairResult:
    """Evaluate one fixture at both fidelities and summarize their agreement."""
    data = _load_fixture(path)
    params = LowDimBoundaryParams.model_validate(data["params"])
    boundary = build_boundary_from_params(params, n_field_periods=N_FIELD_PERIODS)

    low_metrics = evaluate_boundary(boundary, fidelity="low")
    high_metrics = evaluate_boundary(boundary, fidelity="high")
    low_payload = _metrics_payload(low_metrics)
    high_payload = _metrics_payload(high_metrics)

    stored_low = data.get("low_fidelity", {})
    low_snapshot_ok, low_snapshot = _compare_low_snapshot(stored_low, low_payload)

    feasible_match = low_metrics.constraints_satisfied == high_metrics.constraints_satisfied
    # Ranking compatibility is undecidable when either evaluation failed.
    if low_metrics.evaluation_failed or high_metrics.evaluation_failed:
        ranking_compat = "ambiguous"
    elif feasible_match:
        ranking_compat = "match"
    else:
        ranking_compat = "mismatch"

    comparison: dict[str, Any] = {
        "low_high_feasibility_match": feasible_match,
        "feasibility_delta": high_metrics.p1_feasibility - low_metrics.p1_feasibility,
        "score_delta": high_metrics.p1_score - low_metrics.p1_score,
        "ranking_compatibility": ranking_compat,
        "low_fidelity_stored_p1_score": stored_low.get("p1_score"),
        "low_fidelity_stored_p1_feasibility": stored_low.get("p1_feasibility"),
        "low_fidelity_snapshot": low_snapshot,
    }

    failed = (
        low_metrics.evaluation_failed
        or high_metrics.evaluation_failed
        or not feasible_match
        or not low_snapshot_ok
    )
    status = "fail" if failed else "pass"
    if not low_snapshot_ok:
        print(f" low-fidelity snapshot mismatch:\n{pformat(low_snapshot)}")

    return FixturePairResult(
        name=str(data.get("name", path.stem)),
        file=str(path),
        status=status,
        low_fidelity=low_payload,
        high_fidelity=high_payload,
        comparison=comparison,
    )
|
|
|
|
| def _write_json(payload: dict[str, Any], path: Path) -> None: |
| path.parent.mkdir(parents=True, exist_ok=True) |
| with path.open("w") as file: |
| json.dump(payload, file, indent=2) |
|
|
|
|
def _run_fixture_checks(
    *,
    fixture_dir: Path,
    fixture_output: Path,
    write_fixture_updates: bool,
) -> tuple[list[FixturePairResult], int]:
    """Pair every fixture in *fixture_dir* and write an aggregate summary.

    Args:
        fixture_dir: Directory holding tracked fixture JSON files.
        fixture_output: Path for the aggregate summary payload.
        write_fixture_updates: When True, write each fixture's fresh
            high-fidelity payload (plus a UTC timestamp) back into its file.

    Returns:
        The per-fixture results and the number of failing fixtures.
    """
    results: list[FixturePairResult] = []
    fail_count = 0

    for path in _fixture_files(fixture_dir):
        print(f"Evaluating fixture: {path.name}")
        fixture_start = perf_counter()
        result = _pair_fixture(path)
        if result.status != "pass":
            fail_count += 1
        results.append(result)

        if write_fixture_updates:
            fixture = _load_fixture(path)
            fixture["high_fidelity"] = result.high_fidelity
            fixture["paired_high_fidelity_timestamp_utc"] = datetime.now(tz=UTC).isoformat()
            # Reuse the shared writer so fixture files are serialized exactly
            # like the summary output (indent=2) instead of hand-rolling the
            # open/json.dump pair here.
            _write_json(fixture, path)

        # Elapsed time deliberately includes the optional write-back above.
        elapsed = perf_counter() - fixture_start
        print(
            " done in "
            f"{elapsed:0.1f}s | low_feasible={result.low_fidelity['constraints_satisfied']} "
            f"| high_feasible={result.high_fidelity['constraints_satisfied']} "
            f"| status={result.status}"
        )

    pass_count = len(results) - fail_count
    payload = {
        "timestamp_utc": datetime.now(tz=UTC).isoformat(),
        "n_field_periods": N_FIELD_PERIODS,
        "fixture_count": len(results),
        "pass_count": pass_count,
        "fail_count": fail_count,
        "results": [asdict(result) for result in results],
    }
    _write_json(payload, fixture_output)
    return results, fail_count
|
|
|
|
def _run_submit_trace(
    trace_output: Path,
    *,
    seed: int,
    action_sequence: str,
) -> dict[str, Any]:
    """Replay a manual action sequence against a fresh environment.

    Resets a StellaratorEnvironment with *seed*, steps through the parsed
    *action_sequence* (see `_parse_submit_sequence` for the token format),
    records one trace entry per step plus the reset snapshot as step 0,
    writes the payload to *trace_output*, and returns it.
    """
    env = StellaratorEnvironment()
    obs = env.reset(seed=seed)
    # Capture the post-reset boundary parameters so the trace records the
    # exact starting point of the episode.
    reset_params = env.state.current_params.model_dump()
    actions = _parse_submit_sequence(action_sequence)

    # Step 0 mirrors the TraceStep fields (plus "params") but is built by
    # hand because the deltas are undefined at reset.
    trace: list[dict[str, Any]] = [
        {
            "step": 0,
            "intent": "reset",
            "action": f"reset(seed={seed})",
            "reward": 0.0,
            "score": obs.p1_score,
            "feasibility": obs.p1_feasibility,
            "feasibility_delta": None,
            "score_delta": None,
            "constraints_satisfied": obs.constraints_satisfied,
            "max_elongation": obs.max_elongation,
            "p1_feasibility": obs.p1_feasibility,
            "budget_remaining": obs.budget_remaining,
            "evaluation_fidelity": obs.evaluation_fidelity,
            "done": obs.done,
            "params": reset_params,
        }
    ]

    previous_feasibility = obs.p1_feasibility
    previous_score = obs.p1_score

    for idx, action in enumerate(actions, start=1):
        obs = env.step(action)
        trace.append(
            asdict(
                TraceStep(
                    step=idx,
                    intent=action.intent,
                    action=(
                        f"{action.parameter} {action.direction} {action.magnitude}"
                        if action.intent == "run"
                        else action.intent
                    ),
                    # `or 0.0` maps None (and any falsy reward) to 0.0 so the
                    # total_reward sum below never sees a non-numeric value.
                    reward=float(obs.reward or 0.0),
                    score=obs.p1_score,
                    feasibility=obs.p1_feasibility,
                    constraints_satisfied=obs.constraints_satisfied,
                    feasibility_delta=obs.p1_feasibility - previous_feasibility,
                    score_delta=obs.p1_score - previous_score,
                    max_elongation=obs.max_elongation,
                    p1_feasibility=obs.p1_feasibility,
                    budget_remaining=obs.budget_remaining,
                    evaluation_fidelity=obs.evaluation_fidelity,
                    done=obs.done,
                )
            )
        )

        previous_feasibility = obs.p1_feasibility
        previous_score = obs.p1_score
        if obs.done:
            # Stop once the environment terminates (e.g. after submit);
            # any remaining tokens are not replayed.
            break

    total_reward = sum(step["reward"] for step in trace)
    payload = {
        "trace_label": "submit_side_manual",
        "trace_profile": action_sequence,
        "timestamp_utc": datetime.now(tz=UTC).isoformat(),
        "n_field_periods": N_FIELD_PERIODS,
        "seed": seed,
        "total_reward": total_reward,
        "final_score": obs.p1_score,
        "final_feasibility": obs.p1_feasibility,
        "final_constraints_satisfied": obs.constraints_satisfied,
        "final_evaluation_fidelity": obs.evaluation_fidelity,
        "final_evaluation_failed": obs.evaluation_failed,
        "steps": trace,
        "final_best_low_fidelity_score": obs.best_low_fidelity_score,
        "final_best_low_fidelity_feasibility": obs.best_low_fidelity_feasibility,
        "final_diagnostics_text": obs.diagnostics_text,
    }
    _write_json(payload, trace_output)
    return payload
|
|
|
|
def main() -> int:
    """CLI entry point: run fixture checks and, optionally, the submit trace.

    Returns a process exit code: 0 when every fixture pair passes, 1 otherwise.
    """
    args = parse_args()
    results, fail_count = _run_fixture_checks(
        fixture_dir=args.fixture_dir,
        fixture_output=args.fixture_output,
        write_fixture_updates=not args.no_write_fixture_updates,
    )

    pass_count = len(results) - fail_count
    print(
        f"Paired fixtures: {len(results)} total, {pass_count} pass, {fail_count} fail"
    )
    for result in results:
        low_flag = result.low_fidelity["constraints_satisfied"]
        high_flag = result.high_fidelity["constraints_satisfied"]
        print(f" - {result.name}: {result.status} (low={low_flag} high={high_flag})")

    # Guard clause: when the trace is skipped, only the fixture verdict matters.
    if args.skip_submit_trace:
        return 1 if fail_count else 0

    trace = _run_submit_trace(
        args.trace_output,
        seed=args.seed,
        action_sequence=args.submit_action_sequence,
    )
    print(
        f"Manual submit trace written to {args.trace_output} | "
        f"sequence='{trace['trace_profile']}' | "
        f"final_feasibility={trace['final_feasibility']:.6f} | "
        f"fidelity={trace['final_evaluation_fidelity']}"
    )
    return 1 if fail_count else 0
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s exit code (0 = all fixtures pass, 1 = failures) to the shell.
    raise SystemExit(main())
|
|