| | import utils |
| | import torch |
| | import einops |
| | import numpy as np |
| | from workspaces import base |
| | from utils import get_split_idx |
| | from accelerate import Accelerator |
| |
|
# Index map from named kitchen-environment elements to their slots in the
# flat observation/state vector. "robot" covers the 7 arm joint positions;
# the remaining entries are object states (burners, switch, cabinets,
# microwave door, kettle pose). Indices 0-29 together form the state slice
# used throughout this module (`[:30]`).
OBS_ELEMENT_INDICES = {
    "robot": np.array([0, 1, 2, 3, 4, 5, 6]),
    "bottom burner": np.array([11, 12]),
    "top burner": np.array([15, 16]),
    "light switch": np.array([17, 18]),
    "slide cabinet": np.array([19]),
    "hinge cabinet": np.array([20, 21]),
    "microwave": np.array([22]),
    "kettle": np.array([23, 24, 25, 26, 27, 28, 29]),
}
# NOTE(review): this module-level Accelerator appears unused — the workspace
# method below reads `self.accelerator` instead. Possibly kept for its
# process-group initialization side effect; confirm before removing.
accelerator = Accelerator()
| |
|
| |
|
def calc_state_dist(a, b):
    """Per-element and total mean-squared distance between two state vectors.

    Args:
        a, b: 1-D torch tensors holding flat environment states (at least 30
            dims, covering every index in ``OBS_ELEMENT_INDICES``).

    Returns:
        Dict mapping each element name in ``OBS_ELEMENT_INDICES`` to the MSE
        over that element's indices (0-dim tensor), plus a ``"total"`` key
        with the MSE over the full vectors.
    """
    result = {}
    for name, indices in OBS_ELEMENT_INDICES.items():
        # as_tensor avoids the legacy torch.Tensor(...) constructor, which
        # round-trips the int index array through float32 before .long().
        idx = torch.as_tensor(indices, dtype=torch.long)
        result[name] = ((a[idx] - b[idx]) ** 2).mean()
    result["total"] = ((a - b) ** 2).mean()
    return result
| |
|
| |
|
def mean_dicts(dicts):
    """Average a list of dicts key-wise.

    Args:
        dicts: List of dicts that all share the keys of ``dicts[0]``.
            Values must be accepted by ``np.mean`` (numbers or 0-dim tensors).

    Returns:
        Dict with the same keys, each mapped to the mean of that key's
        values across ``dicts``. Returns ``{}`` for an empty list (the
        previous version raised ``IndexError`` on ``dicts[0]``).
    """
    if not dicts:
        return {}
    return {k: np.mean([d[k] for d in dicts]) for k in dicts[0]}
| |
|
| |
|
class SimKitchenWorkspace(base.Workspace):
    """Workspace for offline evaluation on the sim-kitchen dataset.

    Evaluates a trained encoder by (1) linear-probing trajectory embeddings
    against ground-truth states and actions, and (2) measuring how close the
    state of each embedding-space nearest neighbor is to the query state.
    """

    def __init__(self, cfg, work_dir):
        super().__init__(cfg, work_dir)

    def run_offline_eval(self):
        """Run linear-probe and nearest-neighbor evals.

        Returns a metrics dict on the accelerate main process; returns
        ``None`` on every other process. Note all processes still compute
        the embeddings below before the main-process guard.
        """
        # Deterministic trajectory-level train/val split.
        train_idx, val_idx = get_split_idx(
            len(self.dataset),
            self.cfg.seed,
            train_fraction=self.cfg.train_fraction,
        )

        # Embed every trajectory, then flatten the per-frame view axis:
        # each element goes from (T, V, E) to (T, V*E).
        embeddings = utils.inference.embed_trajectory_dataset(
            self.encoder, self.dataset
        )
        embeddings = [
            einops.rearrange(x, "T V E -> T (V E)") for x in embeddings
        ]
        # Metrics are computed only on the main process; other ranks return None.
        if self.accelerator.is_main_process:
            # Gather per-trajectory ground truth, truncated to each
            # trajectory's true length. The [:30] slice keeps the state
            # dims covered by OBS_ELEMENT_INDICES (robot + objects).
            states = []
            actions = []
            for i in range(len(self.dataset)):
                T = self.dataset.get_seq_length(i)
                states.append(self.dataset.states[i, :T, :30])
                actions.append(self.dataset.actions[i, :T])
            # Linear probe: predict states from embeddings, split by trajectory.
            embd_state_linear_probe_results = (
                utils.inference.linear_probe_with_trajectory_split(
                    embeddings,
                    states,
                    train_idx,
                    val_idx,
                )
            )

            embd_state_linear_probe_results = {
                f"embd_state_{k}": v for k, v in embd_state_linear_probe_results.items()
            }
            # Linear probe: predict actions from embeddings.
            embd_action_linear_probe_results = (
                utils.inference.linear_probe_with_trajectory_split(
                    embeddings,
                    actions,
                    train_idx,
                    val_idx,
                )
            )
            embd_action_linear_probe_results = {
                f"embd_action_{k}": v
                for k, v in embd_action_linear_probe_results.items()
            }

            # Nearest-neighbor retrieval eval: sample N query frames, find
            # the closest embedding among all OTHER trajectories, and score
            # how close the retrieved frame's state is to the query's.
            state_dists = []
            N = 200
            rng = np.random.default_rng(self.cfg.seed)
            for i in range(N):
                query_traj_idx = rng.choice(len(self.dataset))
                # Frames before index 10 are excluded — presumably to skip
                # near-identical trajectory starts; TODO confirm intent.
                query_frame_idx = rng.choice(
                    range(10, self.dataset.get_seq_length(query_traj_idx))
                )
                query_embedding = embeddings[query_traj_idx][query_frame_idx]
                query_frame_state = self.dataset.states[
                    query_traj_idx, query_frame_idx
                ][:30]

                # Candidate pool = every frame from every trajectory except
                # the query's own (prevents trivial self-matches).
                pool_embeddings = torch.cat(
                    [x for i, x in enumerate(embeddings) if i != query_traj_idx]
                )
                pool_states = torch.cat(
                    [x for i, x in enumerate(states) if i != query_traj_idx]
                )
                # k=1 nearest neighbor under MSE in embedding space;
                # nn_idx[0, 0] is the single match for the single query.
                _, nn_idx = utils.inference.batch_knn(
                    query_embedding.unsqueeze(0),
                    pool_embeddings,
                    metric=utils.inference.mse,
                    k=1,
                    batch_size=1,
                )
                closest_frame_state = pool_states[nn_idx[0, 0]][:30]
                state_dist = calc_state_dist(query_frame_state, closest_frame_state)
                state_dists.append(state_dist)
            mean_state_dist = mean_dicts(state_dists)
            return {
                **embd_state_linear_probe_results,
                **embd_action_linear_probe_results,
                **mean_state_dist,
            }
        else:
            return None
| |
|