File size: 10,106 Bytes
9ad6280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
#!/usr/bin/env python3
"""
Evaluate Pi0.5 checkpoints in the RoboCasa kitchen sim.
Compares base model vs finetuned model side by side.

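By default the policy is placed on 'cuda' (the `device` default of load_policy;
main() does not override it), so it competes for GPU memory with any training
job that is running. Pass device='cpu' to load_policy for a CPU-only run.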

Usage:
  python eval_kitchen.py --checkpoint /mnt/hdd/pi05-training/full_run/checkpoints/004000/pretrained_model
  python eval_kitchen.py --checkpoint lerobot/pi05_base  # base model comparison
  python eval_kitchen.py --compare  # run both and save side-by-side
"""

import argparse
import json
import os
import sys
from pathlib import Path

# EGL rendering for headless MuJoCo.
# Set here, ahead of the so100_kitchen_env import further down, so the backend
# choice is already in the process environment when the simulator code loads.
os.environ["MUJOCO_GL"] = "egl"

import imageio
import numpy as np
import torch

# Extend the import search path with three locations: this script's own
# directory, ~/lerobot/src (presumably a lerobot source checkout), and a
# RoboCasa test directory.
# Every insert goes to index 0, so the path added last is searched first.
# NOTE(review): the last entry is an absolute path tied to one machine's layout.
sys.path.insert(0, str(Path(__file__).parent))
sys.path.insert(0, str(Path.home() / "lerobot" / "src"))
sys.path.insert(0, "/mnt/hdd/pi05-training/robocasa_test")

from so100_kitchen_env import SO100KitchenEnv


def load_policy(checkpoint_path, device="cuda"):
    """Load a Pi0.5 policy and put it in eval mode on the requested device.

    Args:
        checkpoint_path: Checkpoint location; converted to ``str`` and passed
            to ``PI05Policy.from_pretrained``.
        device: Target device (string or ``torch.device``). If a CUDA device
            is requested but ``torch.cuda.is_available()`` is false, the
            policy is loaded on the CPU instead and a notice is printed.

    Returns:
        The policy returned by ``from_pretrained``, moved to the device and
        switched to ``eval()``.
    """
    # Imported inside the function, so lerobot is only needed once a policy
    # is actually loaded.
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    # Without this check, .to("cuda") raises on machines that have no usable
    # GPU, which makes the script unusable there.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        print(f"CUDA requested ({device}) but not available; falling back to CPU.")
        device = "cpu"

    print(f"Loading policy from {checkpoint_path} ({device})...")
    policy = PI05Policy.from_pretrained(str(checkpoint_path))
    policy = policy.to(device)
    policy.eval()
    return policy


_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
_TOKENIZER_CACHE = {}


def _get_tokenizer(name=_TOKENIZER_NAME):
    """Return the tokenizer for *name*, instantiating it only on first use."""
    tokenizer = _TOKENIZER_CACHE.get(name)
    if tokenizer is None:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(name)
        _TOKENIZER_CACHE[name] = tokenizer
    return tokenizer


def build_batch(env_obs, camera_image, task, stats, device="cuda"):
    """Convert a kitchen env observation to the Pi0.5 batch format.

    Args:
        env_obs: Observation mapping; only ``env_obs["joint_pos"]`` is read
            and it is treated as joint angles in radians.
        camera_image: ``(H, W, 3)`` uint8 image array.
        task: Instruction text embedded in the language prompt.
        stats: Normalisation statistics; ``stats["observation.state"]`` must
            provide ``"mean"`` and ``"std"``.
        device: Device that every returned tensor is moved to.

    Returns:
        A dict with two image entries (both hold the same resized camera
        image), the zero-padded normalised state, and the tokenised prompt
        with its boolean attention mask.

    Raises:
        ValueError: If the state has more entries than the padded width.
    """
    import torchvision.transforms.functional as TF

    state_dim = 32  # width the state vector is zero-padded to

    # Image: (H, W, 3) uint8 -> (1, 3, 224, 224) float32.
    # ascontiguousarray guards against frames that are views with negative
    # strides (e.g. a flipped render), which torch.from_numpy rejects.
    pixels = np.ascontiguousarray(camera_image)
    image = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0
    image = image.unsqueeze(0)
    image_224 = TF.resize(image, [224, 224], antialias=True)

    # ImageNet normalization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    # Move to the device once; the same tensor feeds both camera slots.
    image_224 = ((image_224 - mean) / std).to(device)

    # State: joint positions in radians -> degrees (LeRobot scale), then normalize
    state_degrees = np.degrees(env_obs["joint_pos"])
    state = torch.tensor(state_degrees, dtype=torch.float32).unsqueeze(0)

    state_mean = torch.tensor(stats["observation.state"]["mean"], dtype=torch.float32)
    state_std = torch.tensor(stats["observation.state"]["std"], dtype=torch.float32)
    state = (state - state_mean) / (state_std + 1e-8)  # epsilon avoids divide-by-zero

    # Zero-pad to state_dim, using the actual joint count instead of assuming 6.
    num_joints = state.shape[-1]
    if num_joints > state_dim:
        raise ValueError(f"state has {num_joints} entries, more than the padded width {state_dim}")
    state_padded = torch.zeros(1, state_dim)
    state_padded[:, :num_joints] = state

    # Clamp the normalised state to [-1, 1] and map it to integers 0..255 so
    # it can be written into the text prompt.
    state_discrete = ((state[0].clamp(-1, 1) + 1) / 2 * 255).int()
    state_str = " ".join(str(v.item()) for v in state_discrete)
    prompt = f"Task: {task}, State: {state_str};\nAction: "

    # The tokenizer is cached: this function runs repeatedly per episode and
    # re-instantiating the tokenizer each time is wasted work.
    tokens = _get_tokenizer()(
        prompt, padding="max_length", max_length=200,
        truncation=True, return_tensors="pt",
    )

    return {
        "observation.images.base_0_rgb": image_224,
        "observation.images.left_wrist_0_rgb": image_224,
        "observation.state": state_padded.to(device),
        "observation.language.tokens": tokens["input_ids"].to(device),
        "observation.language.attention_mask": tokens["attention_mask"].bool().to(device),
    }


def decode_actions(raw_actions, stats):
    """Convert model output to joint angles in radians.

    Args:
        raw_actions: Tensor indexed as ``[batch, step, dim]``; only batch
            element 0 and the first six action dimensions are used.
        stats: Normalisation statistics; ``stats["action"]`` must provide
            ``"mean"`` and ``"std"`` that broadcast against six dimensions.

    Returns:
        NumPy array of shape ``(steps, 6)``: the actions de-normalised with
        ``std``/``mean`` and then converted from degrees to radians.
    """
    chunk = raw_actions[0, :, :6].detach().cpu()
    # Tensor.numpy() cannot export bfloat16 (and other non-NumPy dtypes), so
    # upcast anything that is not already float32/float64.
    if chunk.dtype not in (torch.float32, torch.float64):
        chunk = chunk.to(torch.float32)
    actions = chunk.numpy()

    action_mean = np.array(stats["action"]["mean"])
    action_std = np.array(stats["action"]["std"])
    actions = actions * action_std + action_mean
    return np.radians(actions)


def run_episode(policy, env, task, stats, num_steps=200, camera="robot_workspace", show_live=True):
    """Roll the policy out in the environment for up to ``num_steps`` steps.

    Args:
        policy: Object providing ``parameters()`` and ``select_action(batch)``.
        env: Environment providing ``reset()``, ``render(camera)`` and
            ``step(action)``; its observations must contain ``"joint_pos"``.
        task: Instruction text forwarded to ``build_batch``.
        stats: Normalisation statistics forwarded to ``build_batch`` and
            ``decode_actions``.
        num_steps: Maximum number of environment steps.
        camera: Camera name passed to ``env.render``.
        show_live: Show every frame in an OpenCV window; pressing ``q`` ends
            the episode early. Disabled automatically, with one printed
            notice, when OpenCV or a display is not available.

    Returns:
        ``(frames, joint_history)``: the list of frames rendered after each
        step, and an array with one row of joint positions per step.
    """
    # NOTE(review): lerobot policies are expected to expose reset() to drop
    # actions cached from a previous rollout - confirm for PI05Policy. The
    # call is skipped when the method does not exist.
    reset_policy = getattr(policy, "reset", None)
    if callable(reset_policy):
        reset_policy()

    obs = env.reset()
    device = next(policy.parameters()).device  # invariant over the episode
    frames = []
    joint_history = []
    chunk_actions = None
    chunk_idx = 0

    for step in range(num_steps):
        # Ask the policy for new actions once the current chunk is used up.
        if chunk_actions is None or chunk_idx >= len(chunk_actions):
            camera_image = env.render(camera)
            with torch.no_grad():
                batch = build_batch(obs, camera_image, task, stats, device=device)
                action = policy.select_action(batch)
                chunk_actions = decode_actions(action.unsqueeze(0), stats)
                chunk_idx = 0

        action = chunk_actions[chunk_idx]
        chunk_idx += 1

        # Reward, done and info are not used: the rollout always continues
        # until num_steps or a user quit.
        obs, _reward, _done, _info = env.step(action)
        frame = env.render(camera)
        frames.append(frame)
        joint_history.append(obs["joint_pos"].copy())

        # Live display via cv2 (static camera)
        if show_live:
            quit_requested = False
            try:
                import cv2

                cv2.imshow("SO-100 Kitchen Sim", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                quit_requested = cv2.waitKey(1) & 0xFF == ord("q")
            except Exception as exc:
                # Broad on purpose: covers a missing cv2 as well as cv2's own
                # errors (e.g. no display). Report once and stop retrying
                # instead of failing silently on every step.
                print(f"Live display unavailable ({exc}); continuing without it.")
                show_live = False
            if quit_requested:
                print("Quit by user")
                break

        if step % 25 == 0:
            pos = obs["joint_pos"]
            print(f"  step {step:>3}: joints=[{pos[0]:.2f} {pos[1]:.2f} {pos[2]:.2f} {pos[3]:.2f} {pos[4]:.2f} {pos[5]:.3f}]")

    return frames, np.array(joint_history)


def _run_viewer(env, args, stats):
    """Drive the simulation from the policy inside MuJoCo's passive viewer."""
    import time

    import mujoco
    import mujoco.viewer
    from so100_kitchen_env import JOINT_NAMES

    policy = load_policy(args.checkpoint or "lerobot/pi05_base")
    obs = env.reset()
    device = next(policy.parameters()).device

    # Loop-invariant lookups: actuator id (may be None) and qpos address of
    # every joint, in JOINT_NAMES order.
    actuator_ids = [env.actuator_ids.get(name) for name in JOINT_NAMES]
    qpos_addrs = [env.model.jnt_qposadr[env.joint_ids[name]] for name in JOINT_NAMES]

    print(f"Launching interactive viewer. Task: '{args.task}'")
    print("Mouse: Left=rotate, Right=pan, Scroll=zoom")
    print("Close window to exit.")

    chunk_actions = None
    chunk_idx = 0
    step = 0
    # The context manager closes the viewer even if the loop raises.
    with mujoco.viewer.launch_passive(env.model, env.data) as viewer:
        while viewer.is_running():
            # Get action from policy
            if chunk_actions is None or chunk_idx >= len(chunk_actions):
                camera_image = env.render("overview")
                with torch.no_grad():
                    batch = build_batch(obs, camera_image, args.task, stats, device=device)
                    action = policy.select_action(batch)
                    chunk_actions = decode_actions(action.unsqueeze(0), stats)
                    chunk_idx = 0

            act = chunk_actions[chunk_idx]
            chunk_idx += 1

            # Apply action to actuators; joints without an actuator are skipped.
            for target, aid in zip(act, actuator_ids):
                if aid is not None:
                    env.data.ctrl[aid] = target

            # Step physics
            mujoco.mj_step(env.model, env.data)
            viewer.sync()

            # Update obs
            joint_pos = np.array([env.data.qpos[adr] for adr in qpos_addrs])
            obs = {"joint_pos": joint_pos}

            step += 1
            if step % 50 == 0:
                print(f"  step {step}: joints=[{' '.join(f'{j:.2f}' for j in joint_pos)}]")

            time.sleep(0.02)  # ~50Hz


def _run_comparison(env, args, stats):
    """Run the base and the finetuned policy and save videos, stills and a summary."""
    print("=== BASE MODEL ===")
    base_policy = load_policy("lerobot/pi05_base")
    base_frames, base_joints = run_episode(base_policy, env, args.task, stats, args.steps)
    del base_policy  # drop the reference before the second model is loaded

    print("\n=== FINETUNED MODEL ===")
    ft_policy = load_policy(args.finetuned_checkpoint)
    ft_frames, ft_joints = run_episode(ft_policy, env, args.task, stats, args.steps)
    del ft_policy

    shared_len = min(len(base_frames), len(ft_frames))
    if shared_len == 0:
        # Nothing to encode or summarise (e.g. --steps 0).
        print("No frames recorded; nothing saved.")
        return

    # Save videos
    imageio.mimsave(f"{args.output_dir}/base_model.mp4", base_frames, fps=25)
    imageio.mimsave(f"{args.output_dir}/finetuned_model.mp4", ft_frames, fps=25)

    # Save side-by-side frames at the start, the quarter points and the end of
    # the shared episode length (0, 50, 100, 150, 199 for 200 steps).
    key_steps = sorted({0, shared_len // 4, shared_len // 2, (3 * shared_len) // 4, shared_len - 1})
    for t in key_steps:
        combined = np.concatenate([base_frames[t], ft_frames[t]], axis=1)
        imageio.imwrite(f"{args.output_dir}/compare_step_{t:03d}.png", combined)

    # Print joint trajectory summary
    print("\n=== COMPARISON ===")
    print(f"Base model - joint range: {base_joints.min(axis=0)} to {base_joints.max(axis=0)}")
    print(f"Finetuned  - joint range: {ft_joints.min(axis=0)} to {ft_joints.max(axis=0)}")
    print(f"Base model - total motion: {np.abs(np.diff(base_joints, axis=0)).sum():.2f} rad")
    print(f"Finetuned  - total motion: {np.abs(np.diff(ft_joints, axis=0)).sum():.2f} rad")

    print(f"\nSaved to {args.output_dir}/")


def _run_single(env, args, stats):
    """Run one checkpoint and save its video plus first/middle/last stills."""
    policy = load_policy(args.checkpoint)
    frames, _joints = run_episode(policy, env, args.task, stats, args.steps)

    if not frames:
        # Indexing frames[0] below would raise on an empty rollout.
        print("No frames recorded; nothing saved.")
        return

    name = Path(args.checkpoint).parent.name if "checkpoint" in args.checkpoint else "model"
    imageio.mimsave(f"{args.output_dir}/{name}.mp4", frames, fps=25)

    # A set avoids writing the same still twice for very short episodes.
    for t in sorted({0, len(frames) // 2, len(frames) - 1}):
        imageio.imwrite(f"{args.output_dir}/{name}_step_{t:03d}.png", frames[t])

    print(f"Saved {len(frames)} frames to {args.output_dir}/")


def main():
    """Parse the command line and run one of the three evaluation modes.

    Mode priority is ``--viewer``, then ``--compare``, then ``--checkpoint``.
    When none of them is given the parser reports the error and exits with a
    non-zero status before the environment is created.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--task", type=str, default="pick up the mug and place it on the plate")
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--output-dir", type=str, default="/mnt/hdd/pi05-training/eval_kitchen")
    parser.add_argument("--compare", action="store_true", help="Run base vs finetuned comparison")
    parser.add_argument("--viewer", action="store_true", help="Use MuJoCo interactive viewer (mouse orbit/pan/zoom)")
    parser.add_argument("--finetuned-checkpoint", type=str,
                        default="/mnt/hdd/pi05-training/full_run/checkpoints/004000/pretrained_model")
    args = parser.parse_args()

    # Reject an unusable invocation before doing any expensive setup.
    if not (args.viewer or args.compare or args.checkpoint):
        parser.error("Specify --checkpoint or --compare")

    os.makedirs(args.output_dir, exist_ok=True)

    with open(Path(__file__).parent / "norm_stats.json", encoding="utf-8") as f:
        stats = json.load(f)

    env = SO100KitchenEnv()

    if args.viewer:
        _run_viewer(env, args, stats)
    elif args.compare:
        _run_comparison(env, args, stats)
    else:
        _run_single(env, args, stats)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()