| """ |
| Run a single benchmark episode and save the rollout as a video. |
| |
| Use this script to sanity-check the environment and action space |
| """ |
|
|
| from pathlib import Path |
| from typing import Literal |
|
|
| import cv2 |
| import imageio |
| import numpy as np |
| import torch |
| import tyro |
|
|
| from robomme.env_record_wrapper import BenchmarkEnvBuilder |
| from robomme.robomme_env.utils import generate_sample_actions |
|
|
| GUI_RENDER = False |
| VIDEO_FPS = 30 |
| VIDEO_OUTPUT_DIR = "sample_run_videos" |
| MAX_STEPS = 300 |
| EPISODE_LIMITS = {"train": 100, "test": 50, "val": 50} |
| VIDEO_BORDER_COLOR = (255, 0, 0) |
| VIDEO_BORDER_THICKNESS = 10 |
| TaskID = Literal[ |
| "BinFill", "PickXtimes", "SwingXtimes", "StopCube", |
| "VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap", |
| "PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder", |
| "MoveCube", "InsertPeg", "PatternLock", "RouteStick", |
| "All", |
| ] |
| ActionSpaceType = Literal["joint_angle", "ee_pose", "waypoint", "multi_choice"] |
| DatasetType = Literal["train", "test", "val"] |
|
|
|
|
| def _to_numpy(t) -> np.ndarray: |
| return t.cpu().numpy() if isinstance(t, torch.Tensor) else np.asarray(t) |
|
|
|
|
| def _frame_from_obs( |
| front: np.ndarray | torch.Tensor, |
| wrist: np.ndarray | torch.Tensor, |
| is_video_demo: bool = False, |
| ) -> np.ndarray: |
| frame = np.hstack([_to_numpy(front), _to_numpy(wrist)]).astype(np.uint8) |
| if is_video_demo: |
| h, w = frame.shape[:2] |
| cv2.rectangle(frame, (0, 0), (w, h), |
| VIDEO_BORDER_COLOR, VIDEO_BORDER_THICKNESS) |
| return frame |
|
|
|
|
| def _validate_episode_index(episode_idx: int, dataset: DatasetType) -> None: |
| if episode_idx == -1: |
| return |
| limit = EPISODE_LIMITS[dataset] |
| if not 0 <= episode_idx < limit: |
| raise ValueError( |
| f"Invalid episode_idx {episode_idx} for '{dataset}'; allowed: [0, {limit})" |
| ) |
|
|
|
|
| def _save_video( |
| frames: list[np.ndarray], |
| task_id: str, |
| episode_idx: int, |
| action_space_type: str, |
| task_goal: str, |
| ) -> Path: |
| video_dir = Path(VIDEO_OUTPUT_DIR) / action_space_type |
| video_dir.mkdir(parents=True, exist_ok=True) |
| path = video_dir / f"{task_id}_ep{episode_idx}_{task_goal}.mp4" |
| imageio.mimsave(str(path), frames, fps=VIDEO_FPS) |
| return path |
|
|
|
|
| def _extract_frames(obs: dict, is_video_demo_fn=None) -> list[np.ndarray]: |
| n = len(obs["front_rgb_list"]) |
| return [ |
| _frame_from_obs( |
| obs["front_rgb_list"][i], |
| obs["wrist_rgb_list"][i], |
| is_video_demo=(is_video_demo_fn(i) if is_video_demo_fn else False), |
| ) |
| for i in range(n) |
| ] |
|
|
|
|
|
|
| |
| |
| |
| def main( |
| dataset: DatasetType = "test", |
| task_id: TaskID = "PickXtimes", |
| action_space_type: ActionSpaceType = "joint_angle", |
| episode_idx: int = 0, |
| ) -> None: |
| """ |
| Run a single benchmark episode and save the rollout as a video. |
| |
| Args: |
| action_space_type: Type of action space to use. |
| dataset: Dataset split (train / test / val). |
| task_id: Task identifier, or "All" to run every task. |
| episode_idx: Episode index (-1 = All episodes). |
| """ |
| task_ids = ( |
| BenchmarkEnvBuilder.get_task_list() if task_id == "All" else [task_id] |
| ) |
| _validate_episode_index(episode_idx, dataset) |
|
|
| for tid in task_ids: |
| env_builder = BenchmarkEnvBuilder( |
| env_id=tid, |
| dataset=dataset, |
| action_space=action_space_type, |
| gui_render=GUI_RENDER, |
| max_steps=MAX_STEPS, |
| ) |
| episodes = ( |
| list(range(env_builder.get_episode_num())) |
| if episode_idx == -1 |
| else [episode_idx] |
| ) |
|
|
| for ep in episodes: |
| print(f"\nRunning task: {tid}, episode: {ep}, action_space: {action_space_type}, dataset: {dataset}") |
| env = env_builder.make_env_for_episode( |
| ep, |
| include_maniskill_obs=True, |
| include_front_depth=True, |
| include_wrist_depth=True, |
| include_front_camera_extrinsic=True, |
| include_wrist_camera_extrinsic=True, |
| include_available_multi_choices=True, |
| include_front_camera_intrinsic=True, |
| include_wrist_camera_intrinsic=True, |
| ) |
| obs, info = env.reset() |
|
|
| task_goal = info["task_goal"] |
| if isinstance(task_goal, list): |
| task_goal = task_goal[0] |
| print(f"Task goal: {task_goal}") |
|
|
| frames = _extract_frames( |
| obs, is_video_demo_fn=lambda i, n=len(obs["front_rgb_list"]): i < n - 1 |
| ) |
|
|
| action_gen = generate_sample_actions(action_space_type, env=env) |
| for action in action_gen: |
| obs, _, terminated, truncated, info = env.step(action) |
| status = info.get("status", "unknown") |
| if status == "error": |
| print(f"Step error: {info.get('error_message', 'unknown error')}") |
| break |
| frames.extend(_extract_frames(obs)) |
|
|
| if GUI_RENDER: |
| env.render() |
| if terminated or truncated: |
| print(f"Outcome: {status}") |
| break |
|
|
| env.close() |
| path = _save_video(frames, tid, ep, action_space_type, task_goal) |
| print(f"Saved video: {path}\n") |
|
|
|
|
| if __name__ == "__main__": |
| tyro.cli(main) |
|
|