| import os |
| import sys |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Optional, Union |
| import draccus |
| import numpy as np |
| import tqdm |
|
|
| from PIL import Image |
| import torch |
|
|
| import tabletop |
| from dm_env import StepType as st |
| import imageio |
| import time |
| import yaml |
| from scripts.agilex_model import create_model |
| import random |
| from PIL import Image, ImageDraw, ImageFont |
|
|
def set_seed_everywhere(seed: int) -> None:
    """Seed every RNG in use (Python, NumPy, PyTorch) for reproducible runs.

    Args:
        seed: Seed value applied to all random number generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for deterministic kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
|
def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, folder=None, subtitile=None):
    """Saves an MP4 replay of an episode under ./rollouts/{DATE}.

    Args:
        rollout_images: Iterable of RGB frames (numpy arrays) to write.
        idx: Episode index embedded in the output filename.
        success: Success flag embedded in the output filename.
        task_description: Task text; slugified into the output filename.
        log_file: Optional open file handle; the saved path is appended to it.
        folder: Optional subfolder name; videos go to ./rollouts/{DATE}/{folder}/videos.
        subtitile: Optional caption drawn centered near the bottom of each frame.
            (Misspelled name kept for backward compatibility with existing callers.)

    Returns:
        The path of the written MP4 file.
    """
    if folder is None:
        rollout_dir = f"./rollouts/{DATE}"
    else:
        rollout_dir = f"./rollouts/{DATE}/{folder}/videos"
    os.makedirs(rollout_dir, exist_ok=True)

    processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50]
    mp4_path = f"{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--task={processed_task_description}.mp4"
    video_writer = imageio.get_writer(mp4_path, fps=30)

    # Load the caption font once (previously reloaded from disk per frame) and
    # fall back to PIL's built-in bitmap font when the DejaVu TTF is absent,
    # instead of crashing with OSError on systems without that font path.
    font = None
    if subtitile:
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", size=24)
        except OSError:
            font = ImageFont.load_default()

    for img in rollout_images:
        if subtitile:
            pil_img = Image.fromarray(img)
            draw = ImageDraw.Draw(pil_img)
            # Center the caption horizontally, 10 px above the bottom edge.
            bbox = draw.textbbox((0, 0), subtitile, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
            text_x = (pil_img.width - text_width) // 2
            text_y = pil_img.height - text_height - 10
            draw.text((text_x, text_y), subtitile, font=font, fill="white")
            img = np.array(pil_img)

        video_writer.append_data(img)

    video_writer.close()
    print(f"Saved rollout MP4 at path {mp4_path}")
    if log_file is not None:
        log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
    return mp4_path
|
|
# Timestamp tags used to organize rollout artifacts on disk.
DATE = time.strftime("%Y_%m_%d")
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
# Silence the HuggingFace tokenizers fork-parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Render MuJoCo headlessly via EGL (no display required).
os.environ["MUJOCO_GL"] = "egl"
|
|
@dataclass
class Config:
    """Evaluation settings, populated from the command line by draccus."""

    checkpoint: Union[str, Path] = ""          # path to the pretrained policy checkpoint
    twin_or_sing: str = 'TwinVLA'              # model variant selector — presumably 'TwinVLA' vs. a single-arm variant; confirm against create_model
    task_name: str = "aloha_dish_drainer"      # tabletop task id (also used as the unnorm key)
    action_space: str = "ee_6d_pos"            # action space passed to tabletop.env
    num_steps_wait: int = 10                   # NOTE(review): unused in this file — possibly consumed elsewhere
    num_trials_per_task: int = 5               # number of evaluation episodes
    action_len: int = 20                       # actions executed per model query (chunk length)
    benchmark: bool = True                     # use env.task.benchmark_init for deterministic resets


    run_id_note: Optional[str] = None          # NOTE(review): unused in this file
    seed: int = 48                             # base random seed (offset per episode)
|
|
@draccus.wrap()
def eval_tabletop(cfg: Config) -> None:
    """Roll out the policy on a tabletop task and write videos plus a summary.

    For each of ``cfg.num_trials_per_task`` episodes: reset the environment,
    query the model for a chunk of ``cfg.action_len`` actions at a time, step
    until max reward or episode end, then save an MP4 replay. Afterwards a
    success-rate/return summary is written to rollouts/{DATE}/.../summary.txt.

    Args:
        cfg: Evaluation settings (parsed from the CLI by draccus).
    """
    set_seed_everywhere(cfg.seed)
    with open('configs/base.yaml', "r") as fp:
        config = yaml.safe_load(fp)

    model = create_model(
        args=config,
        dtype=torch.bfloat16,
        pretrained=cfg.checkpoint,
        pretrained_text_encoder_name_or_path="google/t5-v1_1-xxl",
        pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
        control_frequency=20
    )

    env = tabletop.env(cfg.task_name, cfg.action_space)
    # Hoisted out of the loop: previously this was only assigned inside the
    # episode loop, so num_trials_per_task == 0 raised NameError below.
    env_max_reward = env.task.max_reward
    # Shared tag for the video folder and the summary directory. Path(...).name
    # also works when cfg.checkpoint is a Path (str.split('/') did not).
    run_tag = f"{Path(cfg.checkpoint).name}-{cfg.task_name}-{cfg.seed}"

    highest_rewards = []
    episode_returns = []
    for rollout_id in tqdm.tqdm(range(cfg.num_trials_per_task)):
        # Per-episode reseed so each trial's reset is reproducible.
        np.random.seed(cfg.seed + rollout_id)
        ts = env.reset()
        if cfg.benchmark:
            ts = env.task.benchmark_init(env.physics, rollout_id)
        action_counter = 0
        replay_images = []
        rewards = []
        # Previous-step images: the model is fed (last frame, current frame) pairs.
        last_front_img = ts.observation['images']['back']
        last_right_wrist_img = ts.observation['images']['wrist_right']
        last_left_wrist_img = ts.observation['images']['wrist_left']
        with torch.inference_mode():
            while True:
                obs = ts.observation
                replay_images.append(obs['images']['back'])
                if action_counter == 0:
                    # Chunk exhausted (or first step): query a fresh action chunk.
                    front_img = obs['images']['back']
                    right_wrist_img = obs['images']['wrist_right']
                    left_wrist_img = obs['images']['wrist_left']
                    image_arrs = [
                        last_front_img,
                        last_right_wrist_img,
                        last_left_wrist_img,
                        front_img,
                        right_wrist_img,
                        left_wrist_img,
                    ]
                    images = [Image.fromarray(arr) if arr is not None else None for arr in image_arrs]
                    proprio = torch.tensor(obs['ee_6d_pos']).unsqueeze(0)
                    actions = model.step(
                        proprio=proprio,
                        images=images,
                        instruction=obs['language_instruction']
                    ).squeeze(0).cpu().numpy()

                action = actions[action_counter]
                ts = env.step(action)
                rewards.append(ts.reward)
                action_counter += 1
                if action_counter == cfg.action_len:
                    action_counter = 0
                if ts.reward == env.task.max_reward or ts.step_type == st.LAST:
                    break
                last_front_img = ts.observation['images']['back']
                last_right_wrist_img = ts.observation['images']['wrist_right']
                last_left_wrist_img = ts.observation['images']['wrist_left']

        # dm_env can emit None rewards (e.g. on the reset step); filter them
        # before aggregating — np.max over an object array containing None
        # raised TypeError in the original.
        valid_rewards = np.array([r for r in rewards if r is not None])
        episode_return = valid_rewards.sum() if valid_rewards.size else 0.0
        episode_returns.append(episode_return)
        episode_highest_reward = valid_rewards.max() if valid_rewards.size else 0.0
        highest_rewards.append(episode_highest_reward)

        save_rollout_video(
            replay_images,
            rollout_id,
            success=episode_highest_reward == env_max_reward,
            task_description=cfg.task_name,
            folder=run_tag,
            subtitile=obs['language_instruction'],
        )
        replay_images.clear()

    success_rate = np.mean(np.array(highest_rewards) == env_max_reward)
    avg_return = np.mean(episode_returns)
    summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n'
    for r in range(env_max_reward + 1):
        more_or_equal_r = (np.array(highest_rewards) >= r).sum()
        more_or_equal_r_rate = more_or_equal_r / cfg.num_trials_per_task
        summary_str += f'Reward >= {r}: {more_or_equal_r}/{cfg.num_trials_per_task} = {more_or_equal_r_rate*100}%\n'

    log_dir = Path('rollouts') / DATE / run_tag
    log_dir.mkdir(parents=True, exist_ok=True)
    summary_file = log_dir / "summary.txt"
    with summary_file.open("w") as f:
        f.write(summary_str)
|
|
| if __name__ == "__main__": |
| eval_tabletop() |