import copy
import pickle
from os import getcwd
from os.path import abspath, basename, dirname, join, splitext

import numpy as np
import pybullet as p
import PySimpleGUI as sg
from transforms3d.affines import decompose
from transforms3d.quaternions import mat2quat
from urdfpy import URDF


class PyBulletRecorder:
    """Record PyBullet link poses over time and export them for rendering.

    Objects are registered with :meth:`register_object`, which creates one
    :class:`LinkTracker` per URDF visual. Each :meth:`add_keyframe` call
    snapshots the world pose of every tracked visual, and :meth:`save`
    pickles the per-visual geometry description together with all frames.
    """

    class LinkTracker:
        """Track the world pose of one URDF visual attached to a PyBullet link.

        Args:
            name: Unique key for this visual in recorded states and output.
            body_id: PyBullet body unique id.
            link_id: PyBullet link index; ``-1`` selects the base.
            link_origin: 4x4 homogeneous transform applied on top of the
                pose PyBullet reports for the link.
            mesh_path: Mesh file path, or a primitive tag ('box', ...).
            mesh_scale: Per-axis scale of the visual.
            mesh_material: Optional material object exposing ``name`` and
                ``color``.
        """

        def __init__(self,
                     name,
                     body_id,
                     link_id,
                     link_origin,
                     mesh_path,
                     mesh_scale,
                     mesh_material=None):
            self.body_id = body_id
            self.link_id = link_id
            self.mesh_path = mesh_path
            self.mesh_scale = mesh_scale
            self.mesh_material = mesh_material
            # decompose() -> (translation, rotation, zooms, shears). Only the
            # translation and rotation are kept; mat2quat yields (w, x, y, z),
            # which is reordered to PyBullet's (x, y, z, w).
            translation, rotation, _, _ = decompose(link_origin)
            w, x, y, z = mat2quat(rotation)
            self.link_pose = [translation, [x, y, z, w]]
            self.name = name

        def transform(self, position, orientation):
            """Compose a link pose with this visual's fixed offset."""
            return p.multiplyTransforms(
                position, orientation,
                self.link_pose[0], self.link_pose[1],
            )

        def get_keyframe(self):
            """Return the current world pose as position/orientation lists."""
            if self.link_id == -1:
                position, orientation = p.getBasePositionAndOrientation(
                    self.body_id)
            else:
                link_state = p.getLinkState(self.body_id,
                                            self.link_id,
                                            computeForwardKinematics=True)
                # Entries 4 and 5 are the world pose of the URDF link frame.
                position, orientation = link_state[4], link_state[5]
            position, orientation = self.transform(
                position=position, orientation=orientation)
            return {
                'position': list(position),
                'orientation': list(orientation)
            }

    def __init__(self):
        # One dict per keyframe: tracker name -> pose dict.
        self.states = []
        # (geometry kind, LinkTracker) pairs; kind is one of
        # 'mesh', 'box', 'cylinder', 'sphere'.
        self.links = []

    @staticmethod
    def _link_id_map(body_id, name_encoding):
        """Map the PyBullet link names of ``body_id`` to link indices.

        The base link (named by ``getBodyInfo``) maps to ``-1``; every other
        link maps to the index of the joint whose child it is.
        """
        link_id_map = {p.getBodyInfo(body_id)[0].decode(name_encoding): -1}
        for joint_index in range(p.getNumJoints(body_id)):
            link_name = p.getJointInfo(body_id, joint_index)[12]
            link_id_map[link_name.decode(name_encoding)] = joint_index
        return link_id_map

    @staticmethod
    def _geometry_spec(geometry, dir_path, global_scaling=1):
        """Describe a urdfpy visual geometry as ``(kind, mesh_path, scale)``.

        Mesh paths are resolved against ``dir_path`` (absolute mesh filenames
        are kept as-is). Primitive scales use half extents for boxes,
        (radius, radius, length) for cylinders and the radius on every axis
        for spheres. All scales are multiplied by ``global_scaling``,
        consistently with meshes.
        NOTE(review): box uses half extents but cylinder uses its full
        length -- confirm against the consumer's primitive sizes.

        Returns:
            The tuple above, or None when no supported geometry is set.
        """
        if geometry.mesh is not None:
            mesh = geometry.mesh
            scale = ([global_scaling] * 3 if mesh.scale is None
                     else mesh.scale * global_scaling)
            return 'mesh', join(dir_path, mesh.filename), scale
        if geometry.box is not None:
            return 'box', 'box', geometry.box.size / 2 * global_scaling
        if geometry.cylinder is not None:
            radius = geometry.cylinder.radius * global_scaling
            length = geometry.cylinder.length * global_scaling
            return 'cylinder', 'cylinder', [radius, radius, length]
        if geometry.sphere is not None:
            radius = geometry.sphere.radius * global_scaling
            return 'sphere', 'sphere', [radius, radius, radius]
        return None

    @staticmethod
    def _resolve_material(material, color, suffix):
        """Return the material to record for one visual.

        Without a material or a color override, ``material`` is returned
        unchanged. Otherwise a shallow copy is returned, renamed to
        ``<name>_<suffix>`` and given ``color``. The copy matters because the
        URDF material object may be shared between several visuals.
        """
        if material is None or color is None:
            return material
        recolored = copy.copy(material)
        recolored.name = f"{material.name}_{suffix}"
        recolored.color = color
        return recolored

    def register_object(self, body_id, urdf_path, global_scaling=1, color=None,
                        name_encoding='gb2312'):
        """Register every visual of a loaded PyBullet body for recording.

        Args:
            body_id: Body unique id of the loaded URDF.
            urdf_path: URDF file the body was loaded from; mesh paths are
                resolved relative to its directory.
            global_scaling: Scale the body was loaded with; applied to visual
                offsets and to mesh and primitive sizes.
            color: Optional color override for every visual that has a URDF
                material (visuals without a material keep no material).
            name_encoding: Codec used to decode the byte-string names
                PyBullet returns, so they match the names urdfpy parses.
        """
        link_id_map = self._link_id_map(body_id, name_encoding)
        dir_path = dirname(abspath(urdf_path))
        file_name = splitext(basename(urdf_path))[0]
        robot = URDF.load(urdf_path)
        for link in robot.links:
            if link.name not in link_id_map:
                print("skip links !! ", link.name, link_id_map,
                      len(robot.links),
                      p.getBodyInfo(body_id)[0].decode(name_encoding))
                continue
            link_id = link_id_map[link.name]
            # For the base, the inertial origin is undone before applying the
            # visual offset; child links use the visual offset directly.
            base_offset = (np.linalg.inv(link.inertial.origin)
                           if link_id == -1 else np.identity(4))
            for i, link_visual in enumerate(link.visuals):
                spec = self._geometry_spec(link_visual.geometry, dir_path,
                                           global_scaling)
                if spec is None:
                    continue
                geo_name, mesh_path, mesh_scale = spec
                tracker = PyBulletRecorder.LinkTracker(
                    name=file_name + f'_{body_id}_{link.name}_{i}',
                    body_id=body_id,
                    link_id=link_id,
                    # The tracker keeps only translation and rotation of this
                    # matrix, so (for a positive scale) the scaling only
                    # affects the offset's translation.
                    link_origin=(base_offset @ link_visual.origin)
                    * global_scaling,
                    mesh_path=mesh_path,
                    mesh_scale=mesh_scale,
                    mesh_material=self._resolve_material(
                        link_visual.material, color, body_id))
                self.links.append((geo_name, tracker))

    def add_keyframe(self):
        """Snapshot the current pose of every tracked visual."""
        self.states.append(
            {link.name: link.get_keyframe() for _, link in self.links})

    def prompt_save(self):
        """Ask (via PySimpleGUI) whether and where to save, then reset."""
        layout = [[sg.Text('Do you want to save previous episode?')],
                  [sg.Button('Yes'), sg.Button('No')]]
        window = sg.Window('PyBullet Recorder', layout)
        save = False
        while True:
            event, _ = window.read()
            if event in (None, 'No'):
                break
            if event == 'Yes':
                save = True
                break
        window.close()

        if save:
            layout = [[sg.Text('Where do you want to save it?')],
                      [sg.Text('Path'), sg.InputText(getcwd())],
                      [sg.Button('OK')]]
            window = sg.Window('PyBullet Recorder', layout)
            event, values = window.read()
            window.close()
            # Closing the dialog yields no values; treat that (or an empty
            # path) as "don't save" rather than crashing on values[0].
            path = values[0] if event == 'OK' and values else None
            self.save(path or None)
        self.reset()

    def reset(self):
        """Discard all recorded keyframes (registered links are kept)."""
        self.states = []

    def get_formatted_output(self):
        """Build the exported dict: tracker name -> geometry and frames.

        Mesh entries carry 'mesh_path'; primitive entries ('cube', 'cylinder',
        'sphere') carry 'name'. Visuals with a material also carry
        'mesh_material_name' and 'mesh_material_color'. Entries of an unknown
        geometry kind are skipped.
        """
        output_types = {'mesh': 'mesh', 'box': 'cube',
                        'cylinder': 'cylinder', 'sphere': 'sphere'}
        retval = {}
        for geo_name, link in self.links:
            if geo_name not in output_types:
                continue
            entry = {'type': output_types[geo_name]}
            if geo_name == 'mesh':
                entry['mesh_path'] = link.mesh_path
            else:
                entry['name'] = link.name
            entry['mesh_scale'] = link.mesh_scale
            entry['frames'] = [state[link.name] for state in self.states]
            if link.mesh_material is not None:
                entry['mesh_material_name'] = link.mesh_material.name
                entry['mesh_material_color'] = link.mesh_material.color
            retval[link.name] = entry
        return retval

    def save(self, path):
        """Pickle :meth:`get_formatted_output` to ``path`` (skip if None)."""
        if path is None:
            print("[Recorder] Path is None.. not saving")
            return
        print("[Recorder] Saving state to {}".format(path))
        with open(path, 'wb') as f:
            pickle.dump(self.get_formatted_output(), f)