| | from node import InferenceNode |
| | import json |
| | import torch |
| | from PIL import Image as IMG |
| | import numpy as np |
| | from std_msgs.msg import String, Bool |
| | import argparse |
| | import h5py |
| | import os, pickle |
| | from einops import rearrange |
| | import numpy as np |
| | from PIL import Image |
| | import time |
| | """ |
| | #!/usr/bin/python3 |
| | """ |
| |
|
| | import argparse |
| | import sys |
| | import threading |
| | import time |
| | import yaml |
| | from collections import deque |
| |
|
| | import numpy as np |
| | import torch |
| | from cv_bridge import CvBridge |
| | from geometry_msgs.msg import Twist |
| | from nav_msgs.msg import Odometry |
| | from std_msgs.msg import Header |
| | import cv2 |
| |
|
| | from scripts.agilex_model import create_model |
| |
|
class RDTNode(InferenceNode):
    """Run an RDT policy (built with ``create_model``) on top of ``InferenceNode``.

    The node keeps a cursor (``action_counter``) into the most recently
    predicted action chunk and only re-queries the policy when that cursor
    is back at zero.
    """

    def __init__(self, action_chunk, instruction, ckpt_dir, unnorm_key, hz=20,
                 max_timestep=1000, dataset_name=None, single_arm=True,
                 lang_embed_name=''):
        """Configure the node and subscribe to ``/vla/prompt``.

        Args:
            action_chunk: Maximum number of actions executed from one policy
                prediction before the policy is queried again.
            instruction: Stored as ``obs['language_instruction']``.
            ckpt_dir: Checkpoint path passed to ``create_model`` as
                ``pretrained``; its last path component names the run.
            unnorm_key: Stored on the instance; not read in this class.
            hz, max_timestep, dataset_name: Forwarded to ``InferenceNode``.
            single_arm: If True, ``joint_action`` receives ``None`` for its
                first (left) argument and only one wrist-less camera stream
                is used.
            lang_embed_name: Stem of the language-embedding file loaded from
                ``outs/<lang_embed_name>.pt``.
        """
        # These are assigned before the base initializer runs.
        # NOTE(review): presumably InferenceNode.__init__ reads run_name and/or
        # calls bringup_model(), which needs ckpt_dir/single_arm/lang_embed_name.
        self.ckpt_dir = ckpt_dir
        self.lang_embed_name = f'outs/{lang_embed_name}.pt'
        # normpath strips a trailing slash, so "ckpt/" still yields "rdt_ckpt"
        # instead of the empty "rdt_" that split("/")[-1] would produce.
        self.run_name = f'rdt_{os.path.basename(os.path.normpath(ckpt_dir))}'
        self.single_arm = single_arm
        super().__init__(hz=hz, max_timestep=max_timestep,
                         dataset_name=dataset_name, single_arm=single_arm)
        self.obs['language_instruction'] = f'{instruction}'
        self.action_chunk = action_chunk
        self.action_counter = 0
        self.unnorm_key = unnorm_key
        # Stored under its own name: assigning to ``self.prompt_sub`` would
        # shadow the callback method of the same name on this instance.
        self.prompt_subscription = self._node.create_subscription(
            String, '/vla/prompt', self.prompt_sub, 1)
        self.attn = None

    def prompt_sub(self, msg):
        """Print the policy's answer to a free-form text prompt on ``obs['image']``.

        The message is ignored while no policy has been loaded yet.
        """
        policy = getattr(self, 'policy', None)
        if policy is None:
            return
        pil_image = Image.fromarray(self.obs['image'])
        print(policy.inference_prompt(pil_image, msg.data))

    def bringup_model(self):
        """Load the RDT policy and the precomputed language embeddings."""
        with open('configs/base.yaml', "r", encoding="utf-8") as fp:
            config = yaml.safe_load(fp)
        self.policy = create_model(
            args=config,
            dtype=torch.bfloat16,
            pretrained=self.ckpt_dir,
            pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
            # NOTE(review): fixed at 20 regardless of the node's ``hz``
            # argument -- confirm this is the intended control frequency.
            control_frequency=20,
            single_arm=self.single_arm,
        )
        self.lang_embeddings = torch.load(self.lang_embed_name)["embeddings"]

    def inference_fn(self):
        """Query the policy once and return its action chunk as a NumPy array.

        Images from the two most recent frames are laid out as
        ``[main, left, None]`` per frame, older frame first; the left slot is
        ``None`` in single-arm mode. The proprio input is
        ``joint_pos_buffer[-1][7:]`` in single-arm mode, otherwise the whole
        latest joint vector.
        """
        left_frames = None if self.single_arm else self.left_frame_buffer
        image_arrs = []
        for t in (-2, -1):
            image_arrs.extend([
                self.frame_buffer[t],
                None if left_frames is None else left_frames[t],
                None,
            ])
        images = [None if arr is None else Image.fromarray(arr)
                  for arr in image_arrs]

        joint_pos = self.joint_pos_buffer[-1]
        proprio = torch.tensor(
            joint_pos[7:] if self.single_arm else joint_pos).unsqueeze(0)

        actions = self.policy.step(
            proprio=proprio,
            images=images,
            text_embeds=self.lang_embeddings,
        )
        # Drop the leading batch dimension and move to host memory.
        return actions.squeeze(0).cpu().numpy()

    def inference(self):
        """Execute one action of the current chunk, re-querying at chunk start."""
        if self.action_counter == 0:
            with torch.inference_mode():
                start_time = time.perf_counter()
                self.actions = self.inference_fn()
                print(f'{time.perf_counter() - start_time:.6f} sec')

        action = self.actions[self.action_counter]
        if self.single_arm:
            self.joint_action(None, action)
        else:
            # Split into the two joint_action arguments, first 7 entries
            # first (same left/right split done_callback uses).
            self.joint_action(action[:7], action[7:])

        self.action_counter += 1
        # Wrap at the end of the chunk, but never index past the number of
        # actions the policy actually returned.
        if self.action_counter >= min(self.action_chunk, len(self.actions)):
            self.action_counter = 0

    def done_callback(self, msg):
        """Toggle between starting and stopping an inference episode.

        When idle: if ``data_list`` is set, command the arms to row 5 of
        ``observation/joint_pos`` in episode file ``data_list[num]`` and wait
        2 s; otherwise record the current end-effector poses. Then publish
        ``True`` on ``sync_pub`` and start video recording.

        When running: publish ``False``, reset the robot and the action
        cursor, stop video recording if active, and re-initialize.

        ``msg`` is not read.
        """
        if not self.start:
            if self.data_list is not None:
                skip = 5
                # Context manager closes the HDF5 file; slicing an h5py
                # dataset returns an in-memory NumPy copy, so the values stay
                # valid after the file is closed.
                with h5py.File(self.data_list[self.num], 'r') as root:
                    joint_pos = root['observation']['joint_pos']
                    if self.single_arm:
                        self.target_joint_right = joint_pos[skip, :7]
                    else:
                        self.target_joint_left = joint_pos[skip, :7]
                        self.target_joint_right = joint_pos[skip, 7:]
                if self.single_arm:
                    self.joint_action(None, self.target_joint_right)
                else:
                    self.joint_action(self.target_joint_left,
                                      self.target_joint_right)
                # Presumably lets the arms settle at the start pose.
                time.sleep(2)
            else:
                self.target_ee_left = self.obs['left_pose']
                self.target_ee_right = self.obs['right_pose']
            print('Inference & Video Recording Start')
            self.start = True
            self._publish_sync(True)
            self.window.video_start()
        else:
            self.start = False
            self._publish_sync(False)
            self.init_robot()
            self.action_counter = 0
            if self.window.video_recording:
                self.window.video_stop()
            self.initialize()
            print('Next Inference Ready')

    def _publish_sync(self, value):
        """Publish ``value`` as a ``std_msgs/Bool`` on ``self.sync_pub``."""
        sync_msg = Bool()
        sync_msg.data = value
        self.sync_pub.publish(sync_msg)
| |
|
if __name__ == "__main__":
    import cv2

    ckpt_dir = '/home/univ/workspace/rdt-ckpts/checkpoint-38000'

    action_chunk = 64
    hz = 20

    instruction = 'handover the stuffed doll'
    unnorm_key = 'handover_kirby'
    single_arm = False
    dataset_name = [
        'vla_upright_mug',
        'vla_sweep_screws',
        'vla_pick_ball_place_bin',
        'twinvla_handover_kirby',
        'twinvla_put_bottle',
        'twinvla_detach_ball',
        'twinvla_tear_paper_towel',
    ]
    # Shorter than ``dataset_name``: ``num`` must be a valid index into both.
    lang_embed_name = [
        'upright_mug',
        'sweep_screws',
        'pick_ball_place_bin',
        'handover_kirby',
    ]
    num = 3
    if not 0 <= num < min(len(dataset_name), len(lang_embed_name)):
        raise SystemExit(
            f'num={num} has no matching dataset/language-embedding entry')

    node = RDTNode(
        action_chunk=action_chunk,
        instruction=instruction,
        ckpt_dir=ckpt_dir,
        unnorm_key=unnorm_key,
        hz=hz,
        max_timestep=1000,
        dataset_name=dataset_name[num],
        lang_embed_name=lang_embed_name[num],
        single_arm=single_arm,
    )

    # Display loop. KeyboardInterrupt is a BaseException, so the inner
    # ``except Exception`` does not catch it; it reaches the outer handler,
    # which shuts ROS down and leaves the loop.
    try:
        while True:
            try:
                if node.single_arm:
                    img = cv2.cvtColor(node.obs['image'], cv2.COLOR_BGR2RGB)
                else:
                    left_img = cv2.cvtColor(node.obs['leftview_image'],
                                            cv2.COLOR_BGR2RGB)
                    right_img = cv2.cvtColor(node.obs['image'],
                                             cv2.COLOR_BGR2RGB)
                    img = cv2.hconcat([left_img, right_img])
                if node.start:
                    node.window.show(img, overlay_img=None,
                                     text=node.obs['language_instruction'])
                else:
                    node.boundary_query()
                    node.window.show(img, overlay_img=node.overlay_img,
                                     text=node.obs['language_instruction'],
                                     grid=node.grid)
            except Exception as e:
                # Report and keep the display loop running.
                print(f"An error occurred: {e}")
    except KeyboardInterrupt:
        node.ros_close()
| |
|
| | |