File size: 5,987 Bytes
9ad6280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#!/usr/bin/env python3
"""
Evaluate a Pi0.5 checkpoint in the SO-100 MuJoCo sim.
Renders a video of the model controlling the arm.

Usage:
  python eval_sim.py --checkpoint outputs/scale_up_1k/checkpoints/000500/pretrained_model
  python eval_sim.py --checkpoint lerobot/pi05_base  # test base model
"""

import argparse
import sys
from pathlib import Path

import imageio
import numpy as np
import torch

sys.path.insert(0, str(Path(__file__).parent))
sys.path.insert(0, str(Path.home() / "lerobot" / "src"))

from gym_so100.env import SO100Env
from gym_so100.constants import normalize_lerobot_to_gym_so100


def load_policy(checkpoint_path, device="cuda"):
    """Load a Pi0.5 policy checkpoint, move it to *device*, and set eval mode."""
    # Imported lazily so the script can print argparse errors without lerobot installed.
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    print(f"Loading policy from {checkpoint_path}...")
    return PI05Policy.from_pretrained(str(checkpoint_path)).to(device).eval()


def _get_tokenizer():
    """Return the PaliGemma tokenizer, loading it only once.

    The original code called ``AutoTokenizer.from_pretrained`` on every
    policy query, which re-reads (and may re-download) the tokenizer each
    time the rollout loop requests a new action chunk.
    """
    if not hasattr(_get_tokenizer, "_cached"):
        from transformers import AutoTokenizer
        _get_tokenizer._cached = AutoTokenizer.from_pretrained(
            "google/paligemma-3b-pt-224"
        )
    return _get_tokenizer._cached


def build_batch(obs, task, stats, device="cuda"):
    """Convert a sim observation to the Pi0.5 batch format.

    Args:
        obs: env observation dict with "pixels" ((H, W, 3) uint8) and
            "agent_pos" (6 joint angles in radians).
        task: natural-language task string embedded in the prompt.
        stats: normalization stats dict (``norm_stats.json``) with
            "observation.state" mean/std in degrees.
        device: torch device the batch tensors are moved to.

    Returns:
        Dict of tensors keyed as the PI05 policy expects (both camera slots
        receive the same base image; state is zero-padded to 32 dims).
    """
    import torchvision.transforms.functional as TF

    # Image: sim gives (H, W, 3) uint8 -> (1, 3, H, W) float32 in [0, 1].
    image = torch.from_numpy(obs["pixels"]).permute(2, 0, 1).float() / 255.0
    image = image.unsqueeze(0)  # add batch dim

    # Resize to the 224x224 input the PaliGemma backbone expects.
    image_224 = TF.resize(image, [224, 224], antialias=True)

    # ImageNet normalization.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    image_224 = (image_224 - mean) / std

    # State: sim gives radians, convert to degrees (LeRobot scale).
    # np.degrees allocates a fresh array, so no defensive copy is needed.
    agent_pos_degrees = np.degrees(obs["agent_pos"])
    state = torch.tensor(agent_pos_degrees, dtype=torch.float32).unsqueeze(0)

    # Normalize state with our training stats (MEAN_STD).
    state_mean = torch.tensor(stats["observation.state"]["mean"], dtype=torch.float32)
    state_std = torch.tensor(stats["observation.state"]["std"], dtype=torch.float32)
    state = (state - state_mean) / (state_std + 1e-8)

    # Pad state to the 32-dim slot the model was trained with.
    state_padded = torch.zeros(1, 32)
    state_padded[:, :6] = state

    tokenizer = _get_tokenizer()

    # Discretize the normalized state into 0-255 bins for the text prompt
    # (Pi0.5 prompt format).
    state_discrete = ((state[0].clamp(-1, 1) + 1) / 2 * 255).int()
    state_str = " ".join(str(v.item()) for v in state_discrete)
    prompt = f"Task: {task}, State: {state_str};\nAction: "

    tokens = tokenizer(
        prompt,
        padding="max_length",
        max_length=200,
        truncation=True,
        return_tensors="pt",
    )

    batch = {
        "observation.images.base_0_rgb": image_224.to(device),
        "observation.images.left_wrist_0_rgb": image_224.to(device),
        "observation.state": state_padded.to(device),
        "observation.language.tokens": tokens["input_ids"].to(device),
        "observation.language.attention_mask": tokens["attention_mask"].bool().to(device),
    }
    return batch


def decode_actions(raw_actions, stats):
    """Map raw policy outputs back to LeRobot degree scale, then to sim radians.

    Args:
        raw_actions: (1, chunk_size, >=6) tensor of normalized actions.
        stats: normalization stats dict with "action" mean/std (degrees).

    Returns:
        (chunk_size, 6) numpy array of joint targets in radians.
    """
    # Keep only the first 6 dims (the SO-100 arm joints) of the chunk.
    chunk = raw_actions[0, :, :6].cpu().numpy()

    # Invert the MEAN_STD normalization applied during training.
    action_mean = np.array(stats["action"]["mean"])
    action_std = np.array(stats["action"]["std"])
    degrees = chunk * action_std + action_mean

    # LeRobot works in degrees; the sim wants radians.
    return np.radians(degrees)


def main():
    """Roll out a Pi0.5 checkpoint in the SO-100 sim and save a video.

    Loads normalization stats from ``norm_stats.json`` next to this script,
    queries the policy for action chunks, replays them step-by-step in the
    env, and writes all rendered frames to ``--output`` at 25 fps.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--task", type=str, default="pick up the cube and place it in the bin")
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--output", type=str, default="sim_eval.mp4")
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    import json
    with open(Path(__file__).parent / "norm_stats.json") as f:
        stats = json.load(f)

    # Load policy
    policy = load_policy(args.checkpoint, args.device)

    # Create sim
    env = SO100Env(task="so100_cube_to_bin", obs_type="so100_pixels_agent_pos")
    obs, info = env.reset()

    # Sim joint limits in radians, used to map radian targets into the env's
    # [-1, 1] action space. Hoisted out of the loop (they were rebuilt every
    # step in the original).
    joint_mins = np.array([-1.92, -3.32, -0.174, -1.66, -2.79, -0.174])
    joint_maxs = np.array([1.92, 0.174, 3.14, 1.66, 2.79, 1.75])

    frames = []
    print(f"Running {args.steps} sim steps with task: '{args.task}'")

    chunk_actions = None
    chunk_idx = 0

    for step in range(args.steps):
        # Re-query the policy whenever the current action chunk is exhausted.
        if chunk_actions is None or chunk_idx >= len(chunk_actions):
            with torch.no_grad():
                batch = build_batch(obs, args.task, stats, args.device)
                action = policy.select_action(batch)
                chunk_actions = decode_actions(action.unsqueeze(0), stats)
                chunk_idx = 0

        # Apply one action from the chunk
        action = chunk_actions[chunk_idx]
        chunk_idx += 1

        # Normalize radians to sim's [-1, 1] action space
        sim_action = 2.0 * (action - joint_mins) / (joint_maxs - joint_mins) - 1.0
        sim_action = np.clip(sim_action, -1.0, 1.0)

        obs, reward, terminated, truncated, info = env.step(sim_action.astype(np.float32))

        frames.append(env.render())

        if step % 20 == 0:
            pos = obs["agent_pos"]
            print(f"  step {step:>3}: pos=[{pos[0]:.2f} {pos[1]:.2f} {pos[2]:.2f} {pos[3]:.2f} {pos[4]:.2f} {pos[5]:.3f}] reward={reward:.3f}")

        if terminated or truncated:
            print(f"Episode ended at step {step}")
            break

    # Save video
    imageio.mimsave(args.output, frames, fps=25)
    print(f"Saved {len(frames)} frames to {args.output}")


if __name__ == "__main__":
    main()