pi05-so100-diverse / eval_kitchen.py
bot
Restore all project files from original repo
9ad6280
#!/usr/bin/env python3
"""
Evaluate Pi0.5 checkpoints in the RoboCasa kitchen sim.
Compares base model vs finetuned model side by side.
Note: the policy is loaded onto CUDA by default, so it competes with any running training job for GPU memory.
Usage:
python eval_kitchen.py --checkpoint /mnt/hdd/pi05-training/full_run/checkpoints/004000/pretrained_model
python eval_kitchen.py --checkpoint lerobot/pi05_base # base model comparison
python eval_kitchen.py --compare # run both and save side-by-side
"""
import argparse
import functools
import json
import os
import sys
from pathlib import Path
# EGL rendering for headless MuJoCo
os.environ["MUJOCO_GL"] = "egl"
import imageio
import numpy as np
import torch
sys.path.insert(0, str(Path(__file__).parent))
sys.path.insert(0, str(Path.home() / "lerobot" / "src"))
sys.path.insert(0, "/mnt/hdd/pi05-training/robocasa_test")
from so100_kitchen_env import SO100KitchenEnv
def load_policy(checkpoint_path, device="cuda"):
    """Load a Pi0.5 policy and prepare it for inference.

    Args:
        checkpoint_path: anything ``PI05Policy.from_pretrained`` accepts once
            converted to ``str`` (a local checkpoint directory or a Hub repo id
            such as ``"lerobot/pi05_base"``).
        device: torch device the weights are moved to.

    Returns:
        The policy on ``device``, switched to eval mode.
    """
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    print(f"Loading policy from {checkpoint_path} ({device})...")
    source = str(checkpoint_path)
    model = PI05Policy.from_pretrained(source).to(device)
    model.eval()
    return model
_PALIGEMMA_TOKENIZER = "google/paligemma-3b-pt-224"


@functools.lru_cache(maxsize=4)
def _get_tokenizer(name=_PALIGEMMA_TOKENIZER):
    """Load the Hugging Face tokenizer ``name`` once and reuse it on later calls."""
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained(name)


def build_batch(env_obs, camera_image, task, stats, device="cuda"):
    """Convert a kitchen env observation into the Pi0.5 batch format.

    Args:
        env_obs: env observation; only ``env_obs["joint_pos"]`` (joint angles in
            radians) is read.
        camera_image: ``(H, W, 3)`` uint8 RGB frame.
        task: natural-language instruction embedded in the prompt.
        stats: normalization statistics; ``stats["observation.state"]["mean"]``
            and ``["std"]`` are applied to the joint angles after conversion to
            degrees and must broadcast against the joint vector.
        device: torch device the returned tensors are moved to.

    Returns:
        dict with the same ``(1, 3, 224, 224)`` image under both camera keys,
        a ``(1, 32)`` zero-padded normalized state, and the tokenized prompt
        (``input_ids`` plus a boolean attention mask, padded/truncated to 200).
    """
    import torchvision.transforms.functional as TF

    # Image: (H, W, 3) uint8 -> (1, 3, 224, 224) float32.
    # torch.from_numpy rejects arrays with negative strides (e.g. flipped
    # renders), so make the frame contiguous first (no copy if it already is).
    pixels = np.ascontiguousarray(camera_image)
    image = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0
    image_224 = TF.resize(image.unsqueeze(0), [224, 224], antialias=True)
    # ImageNet normalization.
    # NOTE(review): confirm the checkpoint expects ImageNet-normalized pixels;
    # Pi0-family preprocessing may already rescale [0, 1] inputs itself.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    image_224 = (image_224 - mean) / std

    # State: joint positions in radians -> degrees (LeRobot scale), then normalize.
    state_degrees = np.degrees(env_obs["joint_pos"])
    state = torch.tensor(state_degrees, dtype=torch.float32).unsqueeze(0)
    state_mean = torch.tensor(stats["observation.state"]["mean"], dtype=torch.float32)
    state_std = torch.tensor(stats["observation.state"]["std"], dtype=torch.float32)
    state = (state - state_mean) / (state_std + 1e-8)
    # Zero-pad to the 32-dim state width used by this batch format.
    state_padded = torch.zeros(1, 32)
    state_padded[:, : state.shape[-1]] = state

    # Prompt: each normalized state dim is clamped to [-1, 1] and quantized to 0..255.
    state_discrete = ((state[0].clamp(-1, 1) + 1) / 2 * 255).int()
    state_str = " ".join(str(v.item()) for v in state_discrete)
    prompt = f"Task: {task}, State: {state_str};\nAction: "
    # Cached: loading the tokenizer from disk on every policy query is slow.
    tokens = _get_tokenizer()(
        prompt, padding="max_length", max_length=200,
        truncation=True, return_tensors="pt",
    )
    # The single rendered view is fed as both the base and the wrist camera.
    return {
        "observation.images.base_0_rgb": image_224.to(device),
        "observation.images.left_wrist_0_rgb": image_224.to(device),
        "observation.state": state_padded.to(device),
        "observation.language.tokens": tokens["input_ids"].to(device),
        "observation.language.attention_mask": tokens["attention_mask"].bool().to(device),
    }
def decode_actions(raw_actions, stats):
    """Convert normalized model actions to joint-angle targets in radians.

    Args:
        raw_actions: tensor indexed as ``[batch, step, dim]``; only batch 0 and
            the first 6 action dims are used.
        stats: normalization statistics; ``stats["action"]["mean"]`` and
            ``["std"]`` must broadcast against 6 values. Denormalized values are
            treated as degrees, matching the scale ``build_batch`` uses.

    Returns:
        ``np.ndarray`` of shape ``(steps, 6)`` with joint targets in radians.
    """
    chunk = raw_actions[0, :, :6].detach()
    # NumPy has no bfloat16, so Tensor.numpy() raises on bf16 model outputs.
    if chunk.dtype == torch.bfloat16:
        chunk = chunk.float()
    actions = chunk.cpu().numpy()
    action_mean = np.asarray(stats["action"]["mean"])
    action_std = np.asarray(stats["action"]["std"])
    # Undo mean/std normalization, then convert degrees -> radians for the sim.
    return np.radians(actions * action_std + action_mean)
def run_episode(policy, env, task, stats, num_steps=200, camera="robot_workspace", show_live=True):
    """Run one episode of ``policy`` in ``env`` and record frames and joints.

    The policy is queried only when the current decoded action chunk is used
    up; otherwise the next action from the chunk is applied.

    Args:
        policy: Pi0.5 policy as returned by ``load_policy``.
        env: kitchen env exposing ``reset()``, ``step(action)`` and ``render(camera)``.
        task: natural-language instruction placed in the prompt.
        stats: normalization statistics (see ``build_batch`` / ``decode_actions``).
        num_steps: maximum number of env steps.
        camera: camera used both as policy input and for the recorded frames.
        show_live: if True, try to show each frame in an OpenCV window; pressing
            'q' ends the episode early. A display failure prints one warning and
            disables the window instead of aborting the rollout.

    Returns:
        ``(frames, joints)``: list of rendered frames (one per executed step)
        and an ``np.ndarray`` stacking ``obs["joint_pos"]`` after each step.
    """
    # Clear per-episode policy state (e.g. queued actions) so a reused policy
    # starts fresh. NOTE(review): relies on LeRobot's policy.reset() convention.
    reset_policy = getattr(policy, "reset", None)
    if callable(reset_policy):
        reset_policy()
    device = next(policy.parameters()).device

    # Live display via cv2 (static camera); optional, best-effort.
    cv2 = None
    if show_live:
        try:
            import cv2
        except ImportError as err:
            print(f"Live display disabled ({err})")

    obs = env.reset()
    frames = []
    joint_history = []
    chunk_actions = None
    chunk_idx = 0
    for step in range(num_steps):
        if chunk_actions is None or chunk_idx >= len(chunk_actions):
            camera_image = env.render(camera)
            with torch.no_grad():
                batch = build_batch(obs, camera_image, task, stats, device=device)
                action = policy.select_action(batch)
                chunk_actions = decode_actions(action.unsqueeze(0), stats)
            chunk_idx = 0
        action = chunk_actions[chunk_idx]
        chunk_idx += 1
        obs, _reward, _done, _info = env.step(action)
        frame = env.render(camera)
        frames.append(frame)
        joint_history.append(obs["joint_pos"].copy())

        if cv2 is not None:
            try:
                cv2.imshow("SO-100 Kitchen Sim", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                key = cv2.waitKey(1) & 0xFF
            except Exception as err:  # e.g. headless OpenCV build; never abort the rollout
                print(f"Live display disabled ({err})")
                cv2 = None
            else:
                if key == ord('q'):
                    print("Quit by user")
                    break

        if step % 25 == 0:
            pos = obs["joint_pos"]
            print(f"  step {step:>3}: joints=[{pos[0]:.2f} {pos[1]:.2f} {pos[2]:.2f} {pos[3]:.2f} {pos[4]:.2f} {pos[5]:.3f}]")

    if cv2 is not None:
        try:
            cv2.destroyAllWindows()
        except cv2.error as err:
            print(f"Could not close live display ({err})")
    return frames, np.array(joint_history)
def _run_viewer(env, stats, args):
    """Drive the policy inside MuJoCo's passive viewer until its window is closed."""
    import time as _time

    import mujoco
    import mujoco.viewer

    from so100_kitchen_env import JOINT_NAMES

    policy = load_policy(args.checkpoint or "lerobot/pi05_base", device=args.device)
    device = next(policy.parameters()).device
    # Resolve actuator and qpos indices once; the model does not change while running.
    ctrl_targets = [
        (i, aid)
        for i, aid in enumerate(env.actuator_ids.get(name) for name in JOINT_NAMES)
        if aid is not None
    ]
    qpos_addrs = [env.model.jnt_qposadr[env.joint_ids[name]] for name in JOINT_NAMES]

    obs = env.reset()
    chunk_actions = None
    chunk_idx = 0
    print(f"Launching interactive viewer. Task: '{args.task}'")
    print("Mouse: Left=rotate, Right=pan, Scroll=zoom")
    print("Close window to exit.")
    # The context manager closes the viewer even if the loop raises.
    with mujoco.viewer.launch_passive(env.model, env.data) as viewer:
        step = 0
        while viewer.is_running():
            # Get action from policy
            if chunk_actions is None or chunk_idx >= len(chunk_actions):
                camera_image = env.render("overview")
                with torch.no_grad():
                    batch = build_batch(obs, camera_image, args.task, stats, device=device)
                    action = policy.select_action(batch)
                    chunk_actions = decode_actions(action.unsqueeze(0), stats)
                chunk_idx = 0
            act = chunk_actions[chunk_idx]
            chunk_idx += 1
            # The passive viewer runs its own thread; hold its lock while
            # mutating physics state. sync() takes the lock itself, so call it outside.
            with viewer.lock():
                for i, aid in ctrl_targets:
                    env.data.ctrl[aid] = act[i]
                mujoco.mj_step(env.model, env.data)
            viewer.sync()
            joint_pos = np.array([env.data.qpos[adr] for adr in qpos_addrs])
            obs = {"joint_pos": joint_pos}
            step += 1
            if step % 50 == 0:
                print(f"  step {step}: joints=[{' '.join(f'{j:.2f}' for j in joint_pos)}]")
            _time.sleep(0.02)  # ~50Hz


def _run_compare(env, stats, args, out_dir):
    """Roll out the base and finetuned policies on the same task and save side-by-side results."""
    print("=== BASE MODEL ===")
    base_policy = load_policy("lerobot/pi05_base", device=args.device)
    base_frames, base_joints = run_episode(base_policy, env, args.task, stats, args.steps)
    del base_policy
    # Return freed weights to the allocator before loading the second model.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("\n=== FINETUNED MODEL ===")
    ft_policy = load_policy(args.finetuned_checkpoint, device=args.device)
    ft_frames, ft_joints = run_episode(ft_policy, env, args.task, stats, args.steps)
    del ft_policy
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    imageio.mimsave(str(out_dir / "base_model.mp4"), base_frames, fps=25)
    imageio.mimsave(str(out_dir / "finetuned_model.mp4"), ft_frames, fps=25)

    # Side-by-side keyframes at 0, 1/4, 1/2, 3/4 and the last common step
    # (0, 50, 100, 150, 199 for a full 200-step run).
    n = min(len(base_frames), len(ft_frames))
    keyframes = sorted({0, n // 4, n // 2, (3 * n) // 4, n - 1}) if n > 0 else []
    for t in keyframes:
        combined = np.concatenate([base_frames[t], ft_frames[t]], axis=1)
        imageio.imwrite(str(out_dir / f"compare_step_{t:03d}.png"), combined)

    print("\n=== COMPARISON ===")
    print(f"Base model - joint range: {base_joints.min(axis=0)} to {base_joints.max(axis=0)}")
    print(f"Finetuned  - joint range: {ft_joints.min(axis=0)} to {ft_joints.max(axis=0)}")
    print(f"Base model - total motion: {np.abs(np.diff(base_joints, axis=0)).sum():.2f} rad")
    print(f"Finetuned  - total motion: {np.abs(np.diff(ft_joints, axis=0)).sum():.2f} rad")
    print(f"\nSaved to {args.output_dir}/")


def _run_single(env, stats, args, out_dir):
    """Roll out one checkpoint and save its video plus first/middle/last frames."""
    policy = load_policy(args.checkpoint, device=args.device)
    frames, _joints = run_episode(policy, env, args.task, stats, args.steps)
    if not frames:
        print("No frames recorded; nothing saved.")
        return
    name = Path(args.checkpoint).parent.name if "checkpoint" in args.checkpoint else "model"
    imageio.mimsave(str(out_dir / f"{name}.mp4"), frames, fps=25)
    for t in sorted({0, len(frames) // 2, len(frames) - 1}):
        imageio.imwrite(str(out_dir / f"{name}_step_{t:03d}.png"), frames[t])
    print(f"Saved {len(frames)} frames to {args.output_dir}/")


def main():
    """Command-line entry point.

    Modes, checked in this order:
      --viewer      drive the policy in MuJoCo's interactive viewer
      --compare     run base vs --finetuned-checkpoint and save videos/keyframes
      --checkpoint  run a single checkpoint and save its video/keyframes
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--task", type=str, default="pick up the mug and place it on the plate")
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--output-dir", type=str, default="/mnt/hdd/pi05-training/eval_kitchen")
    parser.add_argument("--compare", action="store_true", help="Run base vs finetuned comparison")
    parser.add_argument("--viewer", action="store_true", help="Use MuJoCo interactive viewer (mouse orbit/pan/zoom)")
    parser.add_argument("--finetuned-checkpoint", type=str,
                        default="/mnt/hdd/pi05-training/full_run/checkpoints/004000/pretrained_model")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Torch device for the policy (e.g. 'cpu' to leave the GPU to training)")
    args = parser.parse_args()

    # Check the mode before paying for env construction.
    if not (args.viewer or args.compare or args.checkpoint):
        print("Specify --checkpoint or --compare")
        return
    if not args.viewer and args.steps < 1:
        parser.error("--steps must be a positive integer")

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(Path(__file__).parent / "norm_stats.json", encoding="utf-8") as f:
        stats = json.load(f)

    env = SO100KitchenEnv()
    if args.viewer:
        _run_viewer(env, stats, args)
    elif args.compare:
        _run_compare(env, stats, args, out_dir)
    else:
        _run_single(env, stats, args, out_dir)


if __name__ == "__main__":
    main()