| | import json |
| |
|
| | import tensorflow as tf |
| | import yaml |
| |
|
| | from data.preprocess_scripts import * |
| | from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN |
| | from data.utils import capitalize_and_period |
| |
|
| | |
# Datasets routed to `_generate_json_state_nostate_ds` by
# `generate_json_state` (that generator builds the state from the previous
# action instead of reading `step['observation']['arm_concat']`).
DATASET_NAMES_NO_STATE = [
    "nyu_door_opening_surprising_effectiveness",
    "usc_cloth_sim_converted_externally_to_rlds",
    "cmu_franka_exploration_dataset_converted_externally_to_rlds",
    "imperialcollege_sawyer_wrist_cam",
]
| |
|
| | |
# Both files are read once, at import time. The paths are relative, so they
# are resolved against the current working directory of the process.

# Contents of the image-key JSON file, exposed as `IMAGE_KEYS`.
with open('configs/dataset_img_keys.json') as file:
    IMAGE_KEYS = json.load(file)

# Base configuration; the generators below read `config['common']`.
with open('configs/base.yaml') as file:
    config = yaml.safe_load(file)
| |
|
| |
|
def assemble_state_vec(arm_concat: tf.Tensor, arm_format: str,
                       base_concat=None, base_format=None
                       ) -> "tuple[tf.Tensor, tf.Tensor]":
    """
    Assemble the state/action vector from the arm and base.

    Every comma-separated name of a format string is looked up in
    `STATE_VEC_IDX_MAPPING`; the corresponding entry of the concatenated
    values is written at that index of a zero-initialized vector created
    with `tf.zeros(STATE_VEC_LEN)`. A second vector of the same size gets a
    1 at every index that was written.

    Args:
        arm_concat: values of the arm, one entry per name in `arm_format`.
        arm_format: comma-separated names describing `arm_concat`.
        base_concat: optional values of the base; ignored when None.
        base_format: comma-separated names describing `base_concat`.
            Required whenever `base_concat` is given.

    Returns:
        The pair `(state_vec, mask_vec)`, both float32 tensors.

    Raises:
        ValueError: if `base_concat` is given without `base_format`.
        KeyError: if a name is not a key of `STATE_VEC_IDX_MAPPING`.
    """
    if base_concat is not None and base_format is None:
        raise ValueError(
            "`base_format` must be provided when `base_concat` is given.")

    state_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
    mask_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)

    # The arm is always written; the base only when it is provided.
    parts = [(arm_concat, arm_format)]
    if base_concat is not None:
        parts.append((base_concat, base_format))

    for values, part_format in parts:
        names = part_format.split(',')
        # One [index] row per name, as expected by tensor_scatter_nd_update.
        indices = [[STATE_VEC_IDX_MAPPING[name]] for name in names]
        state_vec = tf.tensor_scatter_nd_update(
            state_vec, indices, tf.cast(values, tf.float32))
        mask_vec = tf.tensor_scatter_nd_update(
            mask_vec, indices, tf.ones(len(names), dtype=tf.float32))

    return state_vec, mask_vec
| |
|
| |
|
@tf.autograph.experimental.do_not_convert
def _generate_json_state_agilex(episode: dict, dataset_name: str):
    """
    Generate the json dict, states, masks and actions for a given episode.

    Args:
        episode: dict whose 'steps' entry is iterated step by step. Each
            step is read through `step['action']` ('arm_concat', and when a
            base is present 'base_concat' and 'format') and
            `step['observation']` ('arm_concat', 'format',
            'natural_language_instruction').
        dataset_name: name stored in the returned metadata.

    Returns:
        `(episode_metadata, episode_states, episode_masks, episode_acts)`:
        the metadata dict followed by the per-step state, mask and action
        vectors, each stacked with `tf.stack`.

    Raises:
        ValueError: if `img_history_size` or `action_chunk_size` of the
            config is smaller than 1, or if the episode contains no steps.
    """
    # Validate the constants from the config before doing any work.
    common_cfg = config['common']
    if common_cfg['img_history_size'] < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    if common_cfg['action_chunk_size'] < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    episode_metadata = {
        'dataset_name': dataset_name,
        '#steps': 0,
        'instruction': None
    }

    last_base_act = None
    episode_states = []
    episode_acts = []
    episode_masks = []
    has_base = None
    for step_id, step in enumerate(iter(episode['steps'])):
        action = step['action']
        # Whether the episode has a base is decided once, on the first step.
        if has_base is None:
            has_base = 'base_concat' in action
        base_act = action['base_concat'] if has_base else None

        state = step['observation']
        arm_format = state['format'].numpy().decode('utf-8')
        arm_state = state['arm_concat']

        base_format = None
        base_state = None
        if has_base:
            # The base names are the tail of the action format, starting at
            # the first occurrence of 'base'.
            act_format = action['format'].numpy().decode('utf-8')
            base_format = act_format[act_format.find('base'):]
            # The base state is the previous base action; the first step
            # uses zeros of the same shape as the base action.
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
        last_base_act = base_act

        state_vec, mask_vec = assemble_state_vec(
            arm_state, arm_format, base_state, base_format)

        # The mask returned here is built from the same formats as the one
        # above, so it is identical and is deliberately discarded.
        # NOTE(review): the action vector uses `base_state` (the previous
        # base action), not the current `base_act` - confirm this is intended.
        act_vec, _ = assemble_state_vec(
            action['arm_concat'], arm_format, base_state, base_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)
        episode_acts.append(act_vec)

        # Only the first instruction is kept, so decode it only once.
        if episode_metadata['instruction'] is None:
            instr = step['observation']['natural_language_instruction']
            episode_metadata['instruction'] = capitalize_and_period(
                instr.numpy().decode('utf-8'))

    if not episode_states:
        # Without this guard `step_id` below would be unbound.
        raise ValueError(
            f"Episode of dataset `{dataset_name}` contains no steps.")

    # NOTE(review): this stores the index of the last step (number of steps
    # minus one), not the step count - confirm what consumers expect.
    episode_metadata['#steps'] = step_id

    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)
    episode_acts = tf.stack(episode_acts)

    return episode_metadata, episode_states, episode_masks, episode_acts
| |
|
| |
|
@tf.autograph.experimental.do_not_convert
def _generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for a given episode.

    Args:
        episode: dict whose 'steps' entry is iterated step by step. Each
            step is read through `step['action']` ('base_concat' and
            'format' when a base is present) and `step['observation']`
            ('arm_concat', 'format', 'natural_language_instruction').
        dataset_name: name stored in the returned metadata.

    Returns:
        `(episode_metadata, episode_states, episode_masks)`: the metadata
        dict followed by the per-step state and mask vectors, each stacked
        with `tf.stack`.

    Raises:
        ValueError: if `img_history_size` or `action_chunk_size` of the
            config is smaller than 1, or if the episode contains no steps.
    """
    # Validate the constants from the config before doing any work.
    common_cfg = config['common']
    if common_cfg['img_history_size'] < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    if common_cfg['action_chunk_size'] < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    episode_metadata = {
        'dataset_name': dataset_name,
        '#steps': 0,
        'instruction': None
    }

    last_base_act = None
    episode_states = []
    episode_masks = []
    has_base = None
    for step_id, step in enumerate(iter(episode['steps'])):
        action = step['action']
        # Whether the episode has a base is decided once, on the first step.
        if has_base is None:
            has_base = 'base_concat' in action
        base_act = action['base_concat'] if has_base else None

        state = step['observation']
        arm_format = state['format'].numpy().decode('utf-8')
        arm_state = state['arm_concat']

        base_format = None
        base_state = None
        if has_base:
            # The base names are the tail of the action format, starting at
            # the first occurrence of 'base'.
            act_format = action['format'].numpy().decode('utf-8')
            base_format = act_format[act_format.find('base'):]
            # The base state is the previous base action; the first step
            # uses zeros of the same shape as the base action.
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
        last_base_act = base_act

        state_vec, mask_vec = assemble_state_vec(
            arm_state, arm_format, base_state, base_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Only the first instruction is kept, so decode it only once.
        if episode_metadata['instruction'] is None:
            instr = step['observation']['natural_language_instruction']
            episode_metadata['instruction'] = capitalize_and_period(
                instr.numpy().decode('utf-8'))

    if not episode_states:
        # Without this guard `step_id` below would be unbound.
        raise ValueError(
            f"Episode of dataset `{dataset_name}` contains no steps.")

    # NOTE(review): this stores the index of the last step (number of steps
    # minus one), not the step count - confirm what consumers expect.
    episode_metadata['#steps'] = step_id

    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)

    return episode_metadata, episode_states, episode_masks
| |
|
| |
|
@tf.autograph.experimental.do_not_convert
def _generate_json_state_nostate_ds(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for an episode of a dataset without state.

    Since there is no state, the action of the previous step is used as the
    current state; the first step uses zeros of the same shape as its action.

    Args:
        episode: dict whose 'steps' entry is iterated step by step. Each
            step is read through `step['action']` ('arm_concat', 'format',
            optionally 'base_concat') and
            `step['observation']['natural_language_instruction']`.
        dataset_name: name stored in the returned metadata.

    Returns:
        `(episode_metadata, episode_states, episode_masks)`: the metadata
        dict followed by the per-step state and mask vectors, each stacked
        with `tf.stack`.

    Raises:
        ValueError: if `img_history_size` or `action_chunk_size` of the
            config is smaller than 1, or if the episode contains no steps.
    """
    # Validate the constants from the config before doing any work.
    common_cfg = config['common']
    if common_cfg['img_history_size'] < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    if common_cfg['action_chunk_size'] < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    episode_metadata = {
        'dataset_name': dataset_name,
        '#steps': 0,
        'instruction': None
    }

    last_base_act = None
    last_arm_act = None
    episode_states = []
    episode_masks = []
    has_base = None
    for step_id, step in enumerate(iter(episode['steps'])):
        action = step['action']
        # Whether the episode has a base is decided once, on the first step.
        if has_base is None:
            has_base = 'base_concat' in action
        if has_base:
            base_act = action['base_concat']
            if last_base_act is None:
                # First step: no previous base action, start from zeros.
                last_base_act = base_act * 0

        arm_act = action['arm_concat']
        if last_arm_act is None:
            # First step: no previous arm action, start from zeros.
            last_arm_act = arm_act * 0

        act_format = action['format'].numpy().decode('utf-8')

        # The previous action (arm, then base when present) is the state.
        # Assumes `action['format']` names the arm entries followed by the
        # base entries, matching this concatenation order - TODO confirm.
        if has_base:
            last_act_concat = tf.concat(
                [last_arm_act, last_base_act], axis=0)
        else:
            last_act_concat = last_arm_act
        state_vec, mask_vec = assemble_state_vec(
            last_act_concat, act_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Only the first instruction is kept, so decode it only once.
        if episode_metadata['instruction'] is None:
            instr = step['observation']['natural_language_instruction']
            episode_metadata['instruction'] = capitalize_and_period(
                instr.numpy().decode('utf-8'))

        # The current action becomes the state of the next step.
        last_arm_act = arm_act
        if has_base:
            last_base_act = base_act

    if not episode_states:
        # Without this guard `step_id` below would be unbound.
        raise ValueError(
            f"Episode of dataset `{dataset_name}` contains no steps.")

    # NOTE(review): this stores the index of the last step (number of steps
    # minus one), not the step count - confirm what consumers expect.
    episode_metadata['#steps'] = step_id

    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)

    return episode_metadata, episode_states, episode_masks
| |
|
| |
|
@tf.autograph.experimental.do_not_convert
def generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for an episode.

    The steps are first mapped through the `process_step` function of the
    module-level object named like the dataset, then the episode is handed
    to the generator matching the dataset.

    Args:
        episode: dict with a 'steps' entry supporting `.map(...)`. The entry
            is replaced in place by the mapped steps.
        dataset_name: dataset name, either a `str` or a string `tf.Tensor`
            (which is decoded as UTF-8).

    Returns:
        Whatever the selected generator returns: four values for "agilex"
        (metadata, states, masks, actions), three values otherwise
        (metadata, states, masks).

    Raises:
        KeyError: if no module-level name equals `dataset_name`.
    """
    if isinstance(dataset_name, tf.Tensor):
        dataset_name = dataset_name.numpy().decode('utf-8')

    # The preprocessing object is resolved by name from the module globals.
    raw_steps = episode['steps']
    episode['steps'] = raw_steps.map(globals()[dataset_name].process_step)

    # Pick the generator matching the dataset.
    if dataset_name == "agilex":
        generator = _generate_json_state_agilex
    elif dataset_name in DATASET_NAMES_NO_STATE:
        generator = _generate_json_state_nostate_ds
    else:
        generator = _generate_json_state

    return generator(episode, dataset_name)
| |
|