| from collections import defaultdict |
|
|
| import numpy as np |
| from prismatic.vla.action_tokenizer import ActionTokenizer |
| from transformers import AutoTokenizer |
|
|
|
|
class Solver:
    """Score a predicted robot-control transcript against a ground-truth one.

    Transcripts are free-form text containing up to three labelled sections
    that this class knows how to locate and parse:

      * ``"NEXT GRIPPER:"`` — a 2-D coordinate literal (e.g. ``[x, y]``),
      * ``"MOVEMENT:"``     — either a string of action-token ids or a
        semicolon-separated list of movement phrases such as
        ``"move left 5; open gripper"``,
      * ``"POLICIES:"``     — semicolon-separated action-token strings that
        decode to 7-dim action vectors.

    Parsing is deliberately best-effort: any failure falls back to a sentinel
    value instead of raising, so a malformed prediction scores badly rather
    than crashing the evaluation loop.
    """

    def __init__(self, action_tokenizer=None, verbose=True) -> None:
        self.verbose = verbose
        # Maps action-token strings/ids back to continuous action values.
        # May be None when only the plain-text parsing paths are exercised.
        self.action_tokenizer = action_tokenizer
        # Section markers searched for inside the generated text.
        self.coordinates_key = "NEXT GRIPPER:"
        self.movement_key = "MOVEMENT:"
        self.policy_key = "POLICIES:"

    def compare_movement(self, pred_pos, label_pos):
        """Compare two movement vectors element-wise.

        Returns:
            tuple: (L1 distance, relative L1 distance, exact-match flag).

        NOTE(review): ``relative_dist`` divides the *scalar* L1 distance by
        every component of ``label_pos`` and sums the result; zero components
        produce inf/nan — confirm this is the intended metric.
        """
        dist = np.sum(np.abs(pred_pos - label_pos))
        relative_dist = np.sum(np.abs(dist / label_pos))
        return dist, relative_dist, dist == 0

    def compare_policy(self, pred_pol, label_pol):
        """Element-wise accuracy between two lists of 7-dim action rows.

        Only the first ``min(len(label_pol), len(pred_pol))`` rows are
        compared; each compared row is assumed to have 7 entries.

        Returns:
            float: fraction of matching entries, in [0, 1].
        """
        dist = 0
        cnt = 0
        for i in range(min(len(label_pol), len(pred_pol))):
            for j in range(len(label_pol[0])):
                dist += label_pol[i][j] == pred_pol[i][j]
                cnt += 1
        if cnt == 0:
            # One side was empty: report zero accuracy instead of raising
            # ZeroDivisionError on the division below.
            return 0.0
        assert cnt % 7 == 0  # sanity check: rows are 7-dim action vectors
        return dist / cnt

    def extract_2d_coordinates(self, text):
        """Parse the 2-D coordinate literal after the "NEXT GRIPPER:" marker.

        Falls back to ``[0, 0]`` when the marker is absent or the literal
        cannot be evaluated.
        """
        try:
            coordinates_index = text.index(self.coordinates_key) + len(self.coordinates_key)
            coord = text[coordinates_index:]
            coord = [o for o in coord.split("\n") if len(o.strip()) != 0]
            # SECURITY: eval() on model-generated text can execute arbitrary
            # code. If only list/tuple literals ever occur here,
            # ast.literal_eval would be a safe drop-in replacement.
            coord = eval(coord[0].strip())
        except Exception:
            coord = [0, 0]
        return coord

    def extract_movement_plan(self, text):
        """Parse the section following the "MOVEMENT:" marker.

        Two formats are supported:
          * action-token ids (no "gripper" substring present): decoded via
            the action tokenizer into normalized actions — ``require_unorm``
            is True, i.e. the caller must un-normalize them;
          * movement phrases like ``"move left 5; open gripper"``: converted
            directly into a 7-dim [x, y, z, ox, oy, oz, grip] vector —
            ``require_unorm`` is False.

        Returns:
            tuple: (require_unorm, np.ndarray of 7 values). On any parsing
            failure the vector is filled with the sentinel -100 and
            ``require_unorm`` is None.
        """
        require_unorm = None
        try:
            movement_index = text.index(self.movement_key) + len(self.movement_key)
            movement_level = text[movement_index:]
            movement_level = [o for o in movement_level.split("\n") if len(o.strip()) != 0]
            movement_level = movement_level[0].strip()

            if "gripper" not in movement_level:
                require_unorm = True
                movement_token_ids = self.action_tokenizer.tokenizer(movement_level, add_special_tokens=False).input_ids
                movement_norm = self.action_tokenizer.decode_token_ids_to_actions(np.array(movement_token_ids))
                # Drop the leading token and keep exactly the 7 action dims.
                movement_norm = movement_norm[1:8]
                assert len(movement_norm) == 7
            else:
                require_unorm = False
                movement_level = [o for o in movement_level.split(";") if len(o) > 0]
                movement_level = movement_level[:7]  # at most one phrase per axis

                position = defaultdict(int)
                # First two words of a phrase -> (sign, axis).
                # NOTE(review): swing_* maps onto "ox" exactly like roll_* —
                # confirm this aliasing is intentional.
                movement_to_pos = dict(
                    move_backward=(-1, "y"),
                    move_forward=(1, "y"),
                    move_right=(-1, "x"),
                    move_left=(1, "x"),
                    move_downward=(-1, "z"),
                    move_upward=(1, "z"),
                    roll_downward=(-1, "ox"),
                    roll_upward=(1, "ox"),
                    swing_downward=(-1, "ox"),
                    swing_upward=(1, "ox"),
                    pitch_downward=(-1, "oy"),
                    pitch_upward=(1, "oy"),
                    yaw_downward=(-1, "oz"),
                    yaw_upward=(1, "oz"),
                    rotate_clockwise=(-1, "oz"),
                    rotate_counterclockwise=(1, "oz"),
                    close_gripper=(-1, "grip"),
                    open_gripper=(1, "grip"),
                )

                for ml in movement_level:
                    direction = "_".join(ml.split()[:2])
                    sign, axis = movement_to_pos[direction]
                    # NOTE(review): orientation axes are scaled by 1e-3 while
                    # translation axes get deg->rad (pi/180); confirm the two
                    # scale factors are not swapped.
                    scale = 1
                    if "o" in axis:
                        scale = scale * 1e-3
                    elif "grip" in axis:
                        scale = scale
                    else:
                        scale = scale / 180 * np.pi

                    if "grip" in axis:
                        # Gripper state is binary: 1 = open, 0 = closed.
                        level = round("open" in ml)
                    else:
                        level = int(ml.split()[2])

                    position[axis] += sign * scale * level
                movement_norm = [position[idx] for idx in ["x", "y", "z", "ox", "oy", "oz", "grip"]]

        except Exception:
            # Any malformed input yields the sentinel vector.
            movement_norm = [-100] * 7

        return require_unorm, np.array(movement_norm)

    def extract_action_policies(self, text):
        """Decode the semicolon-separated policies after the "POLICIES:" marker.

        Returns:
            tuple: (list of 7-element action lists, text preceding the
            marker). On failure returns ``([[0] * 7], text)``.
        """
        try:
            if self.policy_key in text:
                policy_index = text.index(self.policy_key) + len(self.policy_key)
                policy = text[policy_index:]
                remain_text = text[: text.index(self.policy_key)]
                policies = [o for o in policy.split("\n") if len(o.strip()) != 0]
                policies = policies[0].strip()
            else:
                # No marker: treat the entire text as the policy section.
                policies = text.strip()
                remain_text = ""

            policies_num = []
            for policy_text in policies.split(";"):
                policy_token = self.action_tokenizer.tokenizer(policy_text, add_special_tokens=False).input_ids
                action_policy = self.action_tokenizer.decode_token_ids_to_actions(np.array(policy_token))
                # Drop the leading token and keep exactly 7 dims.
                action_policy = action_policy[1:]
                action_policy = action_policy[:7]

                if len(action_policy) != 7:
                    # BUGFIX: substitute a zero action for this malformed
                    # policy. The previous code assigned a plain list here and
                    # then called .tolist() on it, raising AttributeError and
                    # silently discarding every successfully parsed policy.
                    policies_num.append([0] * 7)
                else:
                    policies_num.append(action_policy.tolist())

        except Exception:
            policies_num = [[0] * 7]
            remain_text = text

        return policies_num, remain_text

    def evaluate_single(self, ground_truth, prediction, verbose=False):
        """Score one prediction string against one ground-truth string.

        Returns:
            tuple: (next_state_score, policy accuracy, L1 distance,
            relative L1 distance, predicted policies, ground-truth policies).
            ``next_state_score`` is currently always 0 (unused metric slot).
        """
        gt_policies, ground_truth = self.extract_action_policies(ground_truth)
        pred_policies, prediction = self.extract_action_policies(prediction)

        _, pred_movement = self.extract_movement_plan(prediction)
        _, gt_movement = self.extract_movement_plan(ground_truth)

        dist, relative_dist, _ = self.compare_movement(label_pos=gt_movement, pred_pos=pred_movement)

        next_state_score = 0

        acc = self.compare_policy(label_pol=gt_policies, pred_pol=pred_policies)

        return next_state_score, acc, dist, relative_dist, pred_policies, gt_policies

    def evaluate_batch(self, batch_gt, batch_pred, verbose=False):
        """Score aligned batches of ground-truth and predicted texts.

        Returns:
            tuple of six parallel lists (one entry per sample): state scores,
            policy accuracies, L1 distances, relative L1 distances,
            predicted policies, ground-truth policies.
        """
        state_acc_ls = []
        action_acc_ls = []
        L1_loss_ls = []
        relative_L1_loss_ls = []
        pred_policies_ls = []
        gt_policies_ls = []
        for i in range(len(batch_gt)):
            ground_truth = batch_gt[i]
            prediction = batch_pred[i]
            next_state_score, action_policy_score, L1_dist, relative_L1_dist, pred_policies, gt_policies = (
                self.evaluate_single(ground_truth, prediction)
            )
            state_acc_ls.append(next_state_score)
            action_acc_ls.append(action_policy_score)
            L1_loss_ls.append(L1_dist)
            relative_L1_loss_ls.append(relative_L1_dist)
            pred_policies_ls.append(pred_policies)
            gt_policies_ls.append(gt_policies)
            if verbose:
                print(f"Ground Truth:\n\n {ground_truth}")
                print()
                print(f"prediction:\n\n {prediction}")
                print()
                print(f"Ground Truth Policies:\n\n {gt_policies}")
                print(f"prediction policies:\n\n {pred_policies}")
                print("*" * 40)

        return state_acc_ls, action_acc_ls, L1_loss_ls, relative_L1_loss_ls, pred_policies_ls, gt_policies_ls
|
|
|
|
# Module-level setup executed at import time: loads the Llama-2 tokenizer
# from the Hugging Face hub (requires network access and, for this gated
# model, authentication) and wires it into the project's ActionTokenizer
# and a shared Solver instance.
# NOTE(review): if other modules import this file only for the Solver class,
# consider moving these lines under an `if __name__ == "__main__":` guard to
# avoid the import-time download — confirm no callers rely on the module
# attributes `tokenizer`, `action_tokenizer`, or `solver` first.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", model_max_length=2048, padding_side="right")
action_tokenizer = ActionTokenizer(tokenizer)
solver = Solver(action_tokenizer)
|
|
|
|