| |
|
|
| import contextlib |
| import dataclasses |
| import datetime |
| import faulthandler |
| import os |
| import signal |
| import time |
| from moviepy.editor import ImageSequenceClip |
| import numpy as np |
| from openpi_client import image_tools |
| from openpi_client import websocket_client_policy |
| import pandas as pd |
| from PIL import Image |
| from droid.robot_env import RobotEnv |
| import tqdm |
| import tyro |
|
|
# Dump Python tracebacks on fatal signals (segfaults, deadlocks) — useful because
# native camera/robot drivers can hard-crash the process mid-rollout.
faulthandler.enable()


# Control-loop frequency of the DROID robot, in Hz; each rollout step below is
# paced to 1 / DROID_CONTROL_FREQUENCY seconds.
DROID_CONTROL_FREQUENCY = 15
|
|
|
|
@dataclasses.dataclass
class Args:
    """Command-line arguments for the DROID evaluation client (parsed by tyro)."""

    # Hardware serial IDs of the three ZED cameras; matched as substrings against
    # the observation-dict image keys in _extract_observation.
    left_camera_id: str = "<your_camera_id>"
    right_camera_id: str = "<your_camera_id>"
    wrist_camera_id: str = "<your_camera_id>"

    # Which external camera feed ("left" or "right") is sent to the policy.
    # Must be set explicitly; main() asserts on it.
    external_camera: str | None = (
        None
    )

    # Maximum number of control steps per rollout before it ends automatically.
    max_timesteps: int = 600
    # How many actions of each predicted chunk to execute open-loop before
    # re-querying the policy server.
    open_loop_horizon: int = 8

    # Host and port of the remote websocket policy server.
    remote_host: str = "0.0.0.0"
    remote_port: int = (
        8000
    )
|
|
|
|
| |
| |
| |
@contextlib.contextmanager
def prevent_keyboard_interrupt():
    """Delay any Ctrl+C (SIGINT) received inside the managed block.

    While the block runs, an incoming SIGINT only sets a flag instead of raising.
    On exit the original handler is restored and, if a SIGINT was recorded,
    KeyboardInterrupt is raised so the caller still observes the interrupt —
    just after the protected code has finished.
    """
    sigint_seen = [False]
    saved_handler = signal.getsignal(signal.SIGINT)

    def _record(_signum, _frame):
        sigint_seen[0] = True

    signal.signal(signal.SIGINT, _record)
    try:
        yield
    finally:
        signal.signal(signal.SIGINT, saved_handler)
        if sigint_seen[0]:
            raise KeyboardInterrupt
|
|
|
|
def main(args: Args):
    """Run interactive policy-evaluation rollouts on a DROID robot.

    Repeatedly prompts the operator for a language instruction, streams
    observations to a remote websocket policy server, executes the returned
    action chunks open-loop, records a video per rollout, asks the operator
    for a success score, and finally writes all results to a CSV file.

    Args:
        args: Parsed command-line arguments (cameras, server address, horizons).
    """
    # The policy consumes exactly one external camera; force the operator to pick.
    assert (
        args.external_camera is not None and args.external_camera in ["left", "right"]
    ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"

    # Joint-velocity action space with positional gripper control matches how the
    # policy's action chunks are interpreted below.
    env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
    print("Created the droid env!")

    policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)

    df = pd.DataFrame(columns=["success", "duration", "video_filename"])

    while True:
        instruction = input("Enter instruction: ")

        # Open-loop chunk bookkeeping: how many actions of the current predicted
        # chunk have already been executed.
        actions_from_chunk_completed = 0
        pred_action_chunk = None

        # NOTE(review): this timestamp contains ':', which is invalid in Windows
        # filenames — fine on the Linux hosts DROID runs on.
        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
        video = []
        # BUG FIX: initialize t_step so the results row below is well-defined even
        # if max_timesteps == 0 (the for-loop body would never bind it).
        t_step = 0
        bar = tqdm.tqdm(range(args.max_timesteps))
        print("Running rollout... press Ctrl+C to stop early.")
        for t_step in bar:
            start_time = time.time()
            try:
                # Save the very first observation's camera views to disk so the
                # operator can sanity-check camera assignment.
                curr_obs = _extract_observation(
                    args,
                    env.get_observation(),
                    save_to_disk=t_step == 0,
                )

                video.append(curr_obs[f"{args.external_camera}_image"])

                # Re-query the policy once the open-loop horizon of the previous
                # chunk is exhausted (or on the very first step).
                if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
                    actions_from_chunk_completed = 0

                    request_data = {
                        "observation/exterior_image_1_left": image_tools.resize_with_pad(
                            curr_obs[f"{args.external_camera}_image"], 224, 224
                        ),
                        "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
                        "observation/joint_position": curr_obs["joint_position"],
                        "observation/gripper_position": curr_obs["gripper_position"],
                        "prompt": instruction,
                    }

                    # Briefly defer Ctrl+C so a half-received server response
                    # cannot leave pred_action_chunk in a corrupt state.
                    with prevent_keyboard_interrupt():
                        pred_action_chunk = policy_client.infer(request_data)["actions"]
                    assert pred_action_chunk.shape == (10, 8)

                action = pred_action_chunk[actions_from_chunk_completed]
                actions_from_chunk_completed += 1

                # Binarize the gripper command: the gripper tracks poorly with
                # continuous targets, so snap to fully open (0) or closed (1).
                if action[-1].item() > 0.5:
                    action = np.concatenate([action[:-1], np.ones((1,))])
                else:
                    action = np.concatenate([action[:-1], np.zeros((1,))])

                # Clip all commands into the robot's normalized action range.
                action = np.clip(action, -1, 1)

                env.step(action)

                # Pace the loop to the DROID control frequency.
                elapsed_time = time.time() - start_time
                if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
                    time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
            except KeyboardInterrupt:
                break

        save_filename = "video_" + timestamp
        # BUG FIX: guard against an empty rollout (interrupt before the first
        # frame) — np.stack([]) would raise.
        if video:
            video = np.stack(video)
            ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")

        # Ask the operator for a success score in [0, 1]; re-prompt on bad input.
        # BUG FIX: the original divided *every* answer by 100 unconditionally, so
        # "y" recorded 0.01 instead of 1.0, and it crashed on non-numeric input
        # and accepted out-of-range values.
        success: float | None = None
        while success is None:
            response = input(
                "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
            )
            if response == "y":
                success = 1.0
            elif response == "n":
                success = 0.0
            else:
                try:
                    value = float(response)
                except ValueError:
                    print(f"Success must be a number in [0, 100] but got: {response}")
                    continue
                if not (0 <= value <= 100):
                    print(f"Success must be a number in [0, 100] but got: {value}")
                    continue
                success = value / 100

        # BUG FIX: DataFrame.append was removed in pandas 2.0; use pd.concat.
        df = pd.concat(
            [
                df,
                pd.DataFrame(
                    [
                        {
                            "success": success,
                            "duration": t_step,
                            "video_filename": save_filename,
                        }
                    ]
                ),
            ],
            ignore_index=True,
        )

        if input("Do one more eval? (enter y or n) ").lower() != "y":
            break
        env.reset()

    os.makedirs("results", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
    csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
    df.to_csv(csv_filename)
    print(f"Results saved to {csv_filename}")
|
|
|
|
| def _extract_observation(args: Args, obs_dict, *, save_to_disk=False): |
| image_observations = obs_dict["image"] |
| left_image, right_image, wrist_image = None, None, None |
| for key in image_observations: |
| |
| |
| if args.left_camera_id in key and "left" in key: |
| left_image = image_observations[key] |
| elif args.right_camera_id in key and "left" in key: |
| right_image = image_observations[key] |
| elif args.wrist_camera_id in key and "left" in key: |
| wrist_image = image_observations[key] |
|
|
| |
| left_image = left_image[..., :3] |
| right_image = right_image[..., :3] |
| wrist_image = wrist_image[..., :3] |
|
|
| |
| left_image = left_image[..., ::-1] |
| right_image = right_image[..., ::-1] |
| wrist_image = wrist_image[..., ::-1] |
|
|
| |
| robot_state = obs_dict["robot_state"] |
| cartesian_position = np.array(robot_state["cartesian_position"]) |
| joint_position = np.array(robot_state["joint_positions"]) |
| gripper_position = np.array([robot_state["gripper_position"]]) |
|
|
| |
| |
| if save_to_disk: |
| combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1) |
| combined_image = Image.fromarray(combined_image) |
| combined_image.save("robot_camera_views.png") |
|
|
| return { |
| "left_image": left_image, |
| "right_image": right_image, |
| "wrist_image": wrist_image, |
| "cartesian_position": cartesian_position, |
| "joint_position": joint_position, |
| "gripper_position": gripper_position, |
| } |
|
|
|
|
if __name__ == "__main__":
    # Parse command-line flags into an Args instance and run the evaluation loop.
    args: Args = tyro.cli(Args)
    main(args)
|
|