| |
| |
| |
|
|
| """ |
| Gymnasium Wrapper for RANS |
| =========================== |
| Wraps ``RANSEnvironment`` in a standard ``gymnasium.Env`` interface so any |
| Gymnasium-compatible RL library can be used for training: |
| |
| • Stable-Baselines3 (PPO, SAC, TD3, …) |
| • CleanRL |
| • RLlib |
| • TorchRL |
| |
| The wrapper runs the environment **locally** (in-process) — no HTTP server |
| needed. For server-based training, replace ``RANSEnvironment()`` with the |
| ``RANSEnv`` WebSocket client (see remote_train_sb3.py). |
| |
| Usage |
| ----- |
| # Standalone check |
| python examples/gymnasium_wrapper.py |
| |
| # Stable-Baselines3 PPO (requires: pip install stable-baselines3) |
| from examples.gymnasium_wrapper import make_rans_env |
| from stable_baselines3 import PPO |
| |
| env = make_rans_env(task="GoToPosition") |
| model = PPO("MlpPolicy", env, verbose=1) |
| model.learn(total_timesteps=200_000) |
| model.save("rans_ppo_go_to_position") |
| """ |
|
|
from __future__ import annotations

import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import numpy as np
|
|
try:
    import gymnasium as gym
    from gymnasium import spaces
except ImportError:
    # Fail fast with an actionable message.  Write it to stderr so pipelines
    # that capture stdout still surface the error.
    print("gymnasium is required: pip install gymnasium", file=sys.stderr)
    sys.exit(1)
|
|
| |
# Make the repository root importable when this file is run as a script.
# The previous str.replace("examples/gymnasium_wrapper.py", "") approach was
# fragile: it silently no-ops on Windows path separators, on a bare/relative
# __file__, or when the file is invoked from inside examples/.  pathlib
# navigation works regardless of how the script was launched.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from server.rans_environment import RANSEnvironment
from server.spacecraft_physics import SpacecraftConfig
from rans_env.models import SpacecraftAction
|
|
|
|
class RANSGymnasiumEnv(gym.Env):
    """
    Gymnasium-compatible wrapper around ``RANSEnvironment``.

    Observation space:
        Flat Box containing [state_obs, thruster_transforms (flattened),
        thruster_masks, mass, inertia].

    Action space:
        Box([0, 1]^n_thrusters) — continuous thruster activations.

    Parameters
    ----------
    task:
        RANS task name.
    spacecraft_config:
        Physical platform configuration.
    task_config:
        Dict of task hyper-parameters.
    max_episode_steps:
        Hard step limit per episode (enforced by the wrapped environment).
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        task: str = "GoToPosition",
        spacecraft_config: Optional[SpacecraftConfig] = None,
        task_config: Optional[Dict[str, Any]] = None,
        max_episode_steps: int = 500,
    ) -> None:
        super().__init__()
        self._env = RANSEnvironment(
            task=task,
            spacecraft_config=spacecraft_config,
            task_config=task_config,
            max_episode_steps=max_episode_steps,
        )
        # NOTE(review): reaches into a private attribute of RANSEnvironment;
        # a public accessor for the spacecraft would be cleaner — confirm.
        sc = self._env._spacecraft

        # One continuous [0, 1] activation per thruster.
        n = sc.n_thrusters
        self.action_space = spaces.Box(
            low=0.0, high=1.0, shape=(n,), dtype=np.float32
        )

        # The flat observation size depends on the platform (thruster count),
        # so probe one reset to measure it before declaring the space.
        obs0 = self._env.reset()
        flat_obs = self._flatten(obs0)
        dim = flat_obs.shape[0]
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(dim,), dtype=np.float32
        )

        # Cached most-recent flat observation.
        self._last_obs = flat_obs

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict] = None,
    ) -> Tuple[np.ndarray, Dict]:
        """Reset the wrapped environment and return ``(obs, info)``.

        ``seed`` is forwarded to ``gym.Env.reset`` to seed ``self.np_random``;
        ``options`` is accepted for API compatibility but unused.
        """
        super().reset(seed=seed)
        obs = self._env.reset()
        self._last_obs = self._flatten(obs)
        return self._last_obs, {"task": obs.task}

    def step(
        self, action: np.ndarray
    ) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Apply one thruster command and advance the simulation one tick.

        Parameters
        ----------
        action:
            Array-like of per-thruster activations.  Values are clipped to
            the declared ``[0, 1]`` bounds before being forwarded.

        Returns
        -------
        ``(obs, reward, terminated, truncated, info)`` per the Gymnasium API.
        """
        # Accept lists/tuples as well as ndarrays (the original ``.tolist()``
        # call crashed on plain Python sequences), and keep out-of-range
        # policy outputs inside the physical actuation limits.
        act = np.clip(np.asarray(action, dtype=np.float32), 0.0, 1.0)
        result = self._env.step(
            SpacecraftAction(thrusters=act.tolist())
        )
        flat_obs = self._flatten(result)
        reward = float(result.reward or 0.0)
        terminated = bool(result.done)
        # NOTE(review): the inner env enforces max_episode_steps itself and
        # appears to fold time-limit termination into ``done`` — confirm; if
        # it exposes a separate time-limit flag, report it as ``truncated``.
        truncated = False
        self._last_obs = flat_obs
        return flat_obs, reward, terminated, truncated, result.info or {}

    def render(self) -> None:
        """No-op: this wrapper has no renderer (metadata lists no modes)."""
        pass

    def close(self) -> None:
        """No-op: the in-process environment holds no external resources."""
        pass

    @staticmethod
    def _flatten(obs) -> np.ndarray:
        """Flatten the SpacecraftObservation into a 1-D float32 array.

        Layout: [state_obs, thruster_transforms (row-major), thruster_masks,
        mass, inertia] — must stay consistent with ``observation_space``.
        """
        parts = [
            np.array(obs.state_obs, dtype=np.float32),
            np.array(obs.thruster_transforms, dtype=np.float32).flatten(),
            np.array(obs.thruster_masks, dtype=np.float32),
            np.array([obs.mass, obs.inertia], dtype=np.float32),
        ]
        return np.concatenate(parts)
|
|
|
|
def make_rans_env(
    task: str = "GoToPosition",
    task_config: Optional[Dict[str, Any]] = None,
    max_episode_steps: int = 500,
    spacecraft_config: Optional[SpacecraftConfig] = None,
) -> RANSGymnasiumEnv:
    """
    Factory that returns a ``gymnasium.Env``-compatible RANS environment.

    Parameters
    ----------
    task:
        RANS task name.
    task_config:
        Dict of task hyper-parameters.
    max_episode_steps:
        Hard step limit per episode.
    spacecraft_config:
        Physical platform configuration.  Previously the factory could not
        forward this even though the env accepts it; added as a trailing
        keyword so existing positional callers are unaffected.

    Example::

        from examples.gymnasium_wrapper import make_rans_env
        from stable_baselines3 import PPO

        env = make_rans_env(task="GoToPose")
        model = PPO("MlpPolicy", env, verbose=1, n_steps=2048)
        model.learn(total_timesteps=500_000)
    """
    return RANSGymnasiumEnv(
        task=task,
        spacecraft_config=spacecraft_config,
        task_config=task_config,
        max_episode_steps=max_episode_steps,
    )
|
|
|
|
| |
| |
| |
|
|
def _smoke_test() -> None:
    """Run a short random-action rollout on each supported task and print stats."""
    print("RANS Gymnasium Wrapper — smoke test")
    print("=" * 50)

    tasks = (
        "GoToPosition",
        "GoToPose",
        "TrackLinearVelocity",
        "TrackLinearAngularVelocity",
    )
    for task in tasks:
        env = make_rans_env(task=task, max_episode_steps=100)
        obs, info = env.reset()
        print(f"\nTask: {task}")
        print(f" obs shape: {obs.shape}")
        print(f" action shape: {env.action_space.shape}")

        # Random policy for up to 100 steps, stopping on episode end.
        episode_return = 0.0
        for _ in range(100):
            obs, reward, terminated, truncated, info = env.step(
                env.action_space.sample()
            )
            episode_return += reward
            if terminated or truncated:
                break

        print(f" total_reward: {episode_return:.3f}")
        print(f" goal_reached: {info.get('goal_reached', False)}")
        env.close()

    print("\nAll tasks OK.")
|
|
|
|
# Run the smoke test only when executed directly, not when imported.
if __name__ == "__main__":
    _smoke_test()
|
|