| | import pybullet as p |
| | import PySimpleGUI as sg |
| | import pickle |
| | from os import getcwd |
| | from urdfpy import URDF |
| | from os.path import abspath, dirname, basename, splitext |
| | from transforms3d.affines import decompose |
| | from transforms3d.quaternions import mat2quat |
| | import numpy as np |
| |
|
| |
|
class PyBulletRecorder:
    """Record the world poses of PyBullet visuals and export them for rendering.

    Typical flow: call ``register_object`` once per loaded body, call
    ``add_keyframe`` once per step to be recorded, then call ``save`` or
    ``prompt_save`` to pickle the dictionary built by ``get_formatted_output``.
    """

    # Recorder geometry tag -> 'type' value written to the exported dict.
    _OUTPUT_TYPES = {
        'mesh': 'mesh',
        'box': 'cube',
        'cylinder': 'cylinder',
        'sphere': 'sphere',
    }

    class LinkTracker:
        """Track the world pose of a single visual element of a PyBullet link."""

        def __init__(self,
                     name,
                     body_id,
                     link_id,
                     link_origin,
                     mesh_path,
                     mesh_scale,
                     mesh_material=None):
            """Store the visual's metadata and pre-compute its local offset.

            Args:
                name: Unique key for this visual in the exported output.
                body_id: PyBullet body unique id.
                link_id: PyBullet link index; ``-1`` selects the base.
                link_origin: 4x4 homogeneous transform of the visual relative
                    to the frame that ``get_keyframe`` reads for ``link_id``.
                mesh_path: Mesh file path, or a primitive tag such as 'box'.
                mesh_scale: Per-axis scale exported alongside the frames.
                mesh_material: Optional material object; its ``name`` and
                    ``color`` are exported when present.
            """
            self.body_id = body_id
            self.link_id = link_id
            self.mesh_path = mesh_path
            self.mesh_scale = mesh_scale
            self.mesh_material = mesh_material
            # decompose() returns (translation, rotation, zooms, shears); only
            # the rigid part is kept as the local offset.
            translation, rotation = decompose(link_origin)[:2]
            # mat2quat yields (w, x, y, z); PyBullet expects (x, y, z, w).
            w, x, y, z = mat2quat(rotation)
            self.link_pose = [translation, [x, y, z, w]]
            self.name = name

        def transform(self, position, orientation):
            """Apply this visual's local offset to a world-frame link pose."""
            return p.multiplyTransforms(
                position, orientation,
                self.link_pose[0], self.link_pose[1],
            )

        def get_keyframe(self):
            """Return the visual's current world pose as plain lists.

            Returns:
                ``{'position': [x, y, z], 'orientation': [x, y, z, w]}``
            """
            if self.link_id == -1:
                position, orientation = p.getBasePositionAndOrientation(
                    self.body_id)
            else:
                link_state = p.getLinkState(self.body_id,
                                            self.link_id,
                                            computeForwardKinematics=True)
                # Indices 4/5 are the world position/orientation of the URDF
                # link frame (not the link's center of mass).
                position, orientation = link_state[4], link_state[5]
            position, orientation = self.transform(
                position=position, orientation=orientation)
            return {
                'position': list(position),
                'orientation': list(orientation),
            }

    def __init__(self):
        # One dict per add_keyframe() call: tracker name -> pose dict.
        self.states = []
        # (geometry tag, LinkTracker) pairs, in registration order.
        self.links = []

    @staticmethod
    def _decode_name(raw):
        """Decode a name returned by PyBullet as bytes."""
        # NOTE(review): GB2312 is kept from the original code; URDF files are
        # usually UTF-8, so non-ASCII names may fail here -- confirm.
        return raw.decode('gb2312')

    @classmethod
    def _link_id_map(cls, body_id):
        """Map link names of ``body_id`` to PyBullet link indices (-1 = base)."""
        # getBodyInfo()[0] is the base link name.
        link_id_map = {cls._decode_name(p.getBodyInfo(body_id)[0]): -1}
        for link_id in range(p.getNumJoints(body_id)):
            # getJointInfo()[12] is the name of the child link of this joint,
            # whose PyBullet link index equals the joint index.
            link_id_map[cls._decode_name(
                p.getJointInfo(body_id, link_id)[12])] = link_id
        return link_id_map

    @staticmethod
    def _visual_origin(link, link_id, link_visual, global_scaling):
        """Return a visual's 4x4 offset from the frame get_keyframe reads.

        For the base (``link_id == -1``), the visual origin is re-expressed
        relative to ``link.inertial.origin``. This is presumably because
        getBasePositionAndOrientation reports the base's inertial frame.
        The whole matrix is multiplied by ``global_scaling``. LinkTracker keeps
        only translation and rotation from decompose(), so in effect this
        scales the translation.
        """
        if link_id == -1:
            frame = np.linalg.inv(link.inertial.origin)
        else:
            frame = np.identity(4)
        return frame @ link_visual.origin * global_scaling

    def register_object(self, body_id, urdf_path, global_scaling=1, color=None):
        """Create a LinkTracker for every visual of a loaded body.

        Args:
            body_id: PyBullet unique id of the body.
            urdf_path: URDF the body was loaded from. It is parsed with urdfpy
                to obtain visual geometry, origins and materials.
            global_scaling: Uniform scale applied to visual origins and sizes.
                It is expected to match the scaling the body was loaded with.
            color: Optional override colour for every visual that has a
                material. Visuals without a material are left untouched.
        """
        import copy

        link_id_map = self._link_id_map(body_id)
        dir_path = dirname(abspath(urdf_path))
        file_name = splitext(basename(urdf_path))[0]
        robot = URDF.load(urdf_path)
        for link in robot.links:
            if link.name not in link_id_map:
                print("skip links !! ", link.name, link_id_map, len(robot.links),
                      self._decode_name(p.getBodyInfo(body_id)[0]))
                continue
            link_id = link_id_map[link.name]

            for i, link_visual in enumerate(link.visuals):
                mesh_material = link_visual.material
                if mesh_material is not None and color is not None:
                    # Copy before overriding. urdfpy may share one Material
                    # object between visuals, so mutating it in place would
                    # stack suffixes and leak the override into the URDF model.
                    # The body id keeps overridden names unique per body.
                    mesh_material = copy.copy(mesh_material)
                    mesh_material.name = mesh_material.name + f"_{body_id}"
                    mesh_material.color = color

                geometry = link_visual.geometry
                link_origin = self._visual_origin(
                    link, link_id, link_visual, global_scaling)

                # (geometry tag, mesh_path, mesh_scale) for each geometry set.
                specs = []
                if geometry.mesh is not None:
                    print("use mesh", i, link_id_map.keys())
                    if geometry.mesh.scale is None:
                        mesh_scale = [global_scaling] * 3
                    else:
                        mesh_scale = geometry.mesh.scale * global_scaling
                    specs.append(('mesh',
                                  dir_path + '/' + geometry.mesh.filename,
                                  mesh_scale))
                if geometry.box is not None:
                    print("use box", i, link_id_map.keys(), geometry.box.__dict__)
                    # Half extents. Primitives are scaled like their origins.
                    specs.append(('box', 'box',
                                  geometry.box.size / 2 * global_scaling))
                if geometry.cylinder is not None:
                    print("use cylinder", i, link_id_map.keys(),
                          geometry.cylinder.__dict__)
                    radius = geometry.cylinder.radius * global_scaling
                    length = geometry.cylinder.length * global_scaling
                    specs.append(('cylinder', 'cylinder',
                                  [radius, radius, length]))
                if geometry.sphere is not None:
                    print("use sphere", i, link_id_map.keys(),
                          geometry.sphere.__dict__)
                    radius = geometry.sphere.radius * global_scaling
                    specs.append(('sphere', 'sphere', [radius, radius, radius]))

                for geo_name, mesh_path, mesh_scale in specs:
                    self.links.append((geo_name, PyBulletRecorder.LinkTracker(
                        name=file_name + f'_{body_id}_{link.name}_{i}',
                        body_id=body_id,
                        link_id=link_id,
                        link_origin=link_origin,
                        mesh_path=mesh_path,
                        mesh_scale=mesh_scale,
                        mesh_material=mesh_material)))

    def add_keyframe(self):
        """Snapshot the current pose of every tracked visual."""
        self.states.append(
            {link.name: link.get_keyframe() for _, link in self.links})

    def prompt_save(self):
        """Ask via PySimpleGUI whether and where to save, then reset states."""
        layout = [[sg.Text('Do you want to save previous episode?')],
                  [sg.Button('Yes'), sg.Button('No')]]
        window = sg.Window('PyBullet Recorder', layout)
        save = False
        while True:
            event, values = window.read()
            if event in (None, 'No'):
                break
            elif event == 'Yes':
                save = True
                break
        window.close()

        if save:
            layout = [[sg.Text('Where do you want to save it?')],
                      [sg.Text('Path'), sg.InputText(getcwd())],
                      [sg.Button('OK')]]
            window = sg.Window('PyBullet Recorder', layout)
            event, values = window.read()
            window.close()
            # Closing the window instead of pressing OK yields (None, None).
            # Pass None through so save() skips writing instead of crashing.
            path = values[0] if event == 'OK' and values else None
            self.save(path)
        self.reset()

    def reset(self):
        """Drop recorded keyframes; registered links are kept."""
        self.states = []

    def get_formatted_output(self):
        """Build the exportable recording.

        Returns:
            Dict keyed by tracker name. Each entry has ``'type'`` ('mesh',
            'cube', 'cylinder' or 'sphere'), plus ``'mesh_path'`` (meshes) or
            ``'name'`` (primitives). It also has ``'mesh_scale'`` and
            ``'frames'``, with one pose per recorded keyframe. Entries with a
            material additionally carry ``'mesh_material_name'`` and
            ``'mesh_material_color'``.
        """
        retval = {}
        for geo_name, link in self.links:
            output_type = self._OUTPUT_TYPES.get(geo_name)
            if output_type is None:
                continue
            entry = {'type': output_type}
            if geo_name == 'mesh':
                entry['mesh_path'] = link.mesh_path
            else:
                entry['name'] = link.name
            entry['mesh_scale'] = link.mesh_scale
            entry['frames'] = [state[link.name] for state in self.states]
            if link.mesh_material is not None:
                entry['mesh_material_name'] = link.mesh_material.name
                entry['mesh_material_color'] = link.mesh_material.color
            retval[link.name] = entry
        return retval

    def save(self, path):
        """Pickle ``get_formatted_output()`` to ``path``.

        A ``None`` or empty path is reported and skipped. The result is a
        pickle, so only load it from trusted sources.
        """
        if not path:
            print("[Recorder] Path is None.. not saving")
            return
        print("[Recorder] Saving state to {}".format(path))
        # Build the payload before opening, so a failure does not leave a
        # truncated file behind.
        data = self.get_formatted_output()
        with open(path, 'wb') as f:
            pickle.dump(data, f)
| |
|