| """ |
| This file will compute the min, max, mean, and standard deviation of each datasets |
| in `pretrain_datasets.json` or `pretrain_datasets.json`. |
| """ |
|
|
import json
import argparse
import os

import tensorflow as tf
import numpy as np
from tqdm import tqdm

from data.vla_dataset import VLADataset
from data.hdf5_vla_dataset import HDF5VLADataset
from data.preprocess import generate_json_state


@tf.autograph.experimental.do_not_convert
def process_dataset(name_dataset_pair):
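    """Compute per-dimension state statistics (mean, std, min, max) for a
    single (name, dataset) pair by streaming over its episodes."""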
    dataset_iter = name_dataset_pair[1]
    MAX_EPISODES = 100000
    EPS = 1e-8

    episode_cnt = 0
    state_sum = 0
    state_sum_sq = 0
    z_state_sum = 0
    z_state_sum_sq = 0
    state_cnt = 0
    nz_state_cnt = None
    state_max = None
    state_min = None
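    # Stream over the episodes once, accumulating per-dimension running sums.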
    for episode in dataset_iter:
        episode_cnt += 1
        if episode_cnt % 1000 == 0:
            print(f"Processing episode {episode_cnt}/{MAX_EPISODES}")
        if episode_cnt > MAX_EPISODES:
            break
        episode_dict = episode['episode_dict']
        dataset_name = episode['dataset_name']
        res_tup = generate_json_state(episode_dict, dataset_name)
        # The second element of the returned tuple holds the state tensor.
        states = res_tup[1].numpy()

        # Copy with near-zero entries zeroed out; used for the std statistics.
        z_states = states.copy()
        z_states[np.abs(states) <= EPS] = 0

        # Per-dimension count of non-zero entries.
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)

        state_sum += np.sum(states, axis=0)
        state_sum_sq += np.sum(states**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]

        # Element-wise extrema across all episodes seen so far.
        if state_max is None:
            state_max = np.max(states, axis=0)
            state_min = np.min(states, axis=0)
        else:
            state_max = np.maximum(state_max, np.max(states, axis=0))
            state_min = np.minimum(state_min, np.min(states, axis=0))

    # Clamp to at least one to avoid division by zero for dimensions
    # that are near zero at every timestep.
    nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))

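    # Std over the masked values: var = sum(x^2)/N_nz - sum(x)^2 / (N * N_nz),
    # clipped at zero so floating-point round-off cannot yield a negative variance.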
    result = {
        "dataset_name": name_dataset_pair[0],
        "state_mean": (state_sum / state_cnt).tolist(),
        "state_std": np.sqrt(
            np.maximum(
                (z_state_sum_sq / nz_state_cnt)
                - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
                np.zeros_like(state_sum_sq)
            )
        ).tolist(),
        "state_min": state_min.tolist(),
        "state_max": state_max.tolist(),
    }
    return result
|
|
|
|
def process_hdf5_dataset(vla_dataset):
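    """Compute the same per-dimension state statistics over an HDF5VLADataset."""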
    EPS = 1e-8
    episode_cnt = 0
    state_sum = 0
    state_sum_sq = 0
    z_state_sum = 0
    z_state_sum_sq = 0
    state_cnt = 0
    nz_state_cnt = None
    state_max = None
    state_min = None
    for i in tqdm(range(len(vla_dataset))):
        episode = vla_dataset.get_item(i, state_only=True)
        episode_cnt += 1
        states = episode['state']

        # Same near-zero masking and accumulation as in process_dataset.
        z_states = states.copy()
        z_states[np.abs(states) <= EPS] = 0

        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)

        state_sum += np.sum(states, axis=0)
        state_sum_sq += np.sum(states**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]
        if state_max is None:
            state_max = np.max(states, axis=0)
            state_min = np.min(states, axis=0)
        else:
            state_max = np.maximum(state_max, np.max(states, axis=0))
            state_min = np.minimum(state_min, np.min(states, axis=0))

    # Clamp to at least one to avoid division by zero.
    nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))

    # Same masked-std formula as in process_dataset.
    result = {
        "dataset_name": vla_dataset.get_dataset_name(),
        "state_mean": (state_sum / state_cnt).tolist(),
        "state_std": np.sqrt(
            np.maximum(
                (z_state_sum_sq / nz_state_cnt)
                - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
                np.zeros_like(state_sum_sq)
            )
        ).tolist(),
        "state_min": state_min.tolist(),
        "state_max": state_max.tolist(),
    }
    return result
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_type', type=str,
                        default="pretrain",
                        help="Whether to load the pretrain dataset or the finetune dataset.")
    parser.add_argument('--save_path', type=str,
                        default="configs/dataset_stat.json",
                        help="JSON file path to save the dataset statistics.")
    parser.add_argument('--skip_exist', action='store_true',
                        help="Whether to skip datasets whose statistics already exist.")
    parser.add_argument('--hdf5_dataset', action='store_true',
                        help="Whether to load the dataset from HDF5 files.")
    args = parser.parse_args()

    if args.hdf5_dataset:
        vla_dataset = HDF5VLADataset()
        dataset_name = vla_dataset.get_dataset_name()

        # Load any existing statistics so results accumulate across runs.
        try:
            with open(args.save_path, 'r') as f:
                results = json.load(f)
        except FileNotFoundError:
            results = {}
        if args.skip_exist and dataset_name in results:
            print(f"Skipping existing {dataset_name} dataset statistics")
        else:
            print(f"Processing {dataset_name} dataset")
            result = process_hdf5_dataset(vla_dataset)
            results[result["dataset_name"]] = result
            with open(args.save_path, 'w') as f:
                json.dump(results, f, indent=4)
        print("All datasets have been processed.")
        # Terminate immediately; os._exit skips interpreter cleanup.
        os._exit(0)

    vla_dataset = VLADataset(
        seed=0, dataset_type=args.dataset_type, repeat=False)
    name_dataset_pairs = vla_dataset.name2dataset.items()

    for name_dataset_pair in tqdm(name_dataset_pairs):
        # Re-read the statistics file so results from earlier iterations
        # (and earlier runs) are preserved.
        try:
            with open(args.save_path, 'r') as f:
                results = json.load(f)
        except FileNotFoundError:
            results = {}

        if args.skip_exist and name_dataset_pair[0] in results:
            print(f"Skipping existing {name_dataset_pair[0]} dataset statistics")
            continue
        print(f"Processing {name_dataset_pair[0]} dataset")
        result = process_dataset(name_dataset_pair)
        results[result["dataset_name"]] = result

        # Write back after each dataset so partial progress survives a crash.
        with open(args.save_path, 'w') as f:
            json.dump(results, f, indent=4)

    print("All datasets have been processed.")