| | import rclpy |
| | from sensor_msgs.msg import JointState |
| | import numpy as np |
| | import time |
| | import h5py |
| | from real_robot_env import AnubisRobotEnv |
| | import cv2 |
| | from tqdm import tqdm |
| | import os |
| | from scripts.agilex_model import create_model |
| | import random |
| | from PIL import Image, ImageDraw, ImageFont |
| | import yaml |
| | import torch |
| |
|
class RDTInferenceRobotEnv(AnubisRobotEnv):
    """Run an RDT policy checkpoint on the Anubis robot environment.

    The model is queried once every ``ACTION_CHUNK_SIZE`` control steps; the
    returned chunk of actions is cached in ``self.actions`` and replayed one
    action per call to :meth:`inference`.
    """

    # Number of actions consumed from one model prediction before the model is
    # queried again. Assumed to match the checkpoint's predicted chunk length
    # (TODO confirm against configs/base.yaml).
    ACTION_CHUNK_SIZE = 64

    def __init__(self, hz=20, max_timestep=500, task_name='', num_rollout=1, model_name='1st'):
        """Configure the checkpoint path and image history, then init the base env.

        Args:
            hz: Control rate passed to the base env; also used as the model's
                ``control_frequency`` when the model is loaded.
            max_timestep: Forwarded to the base env.
            task_name: Selects the checkpoint directory under ``checkpoints/``.
            num_rollout: Forwarded to the base env.
            model_name: Stored on the instance as ``self.model_name``.
        """
        self.model_name = model_name
        self.checkpoint = f'/home/rllab/workspace/jellyho/RoboticsDiffusionTransformer/checkpoints/{task_name}'
        # Stored before super().__init__() so bringup_model() can use it even if
        # the base class loads the model during construction.
        self._control_hz = hz
        # Previous-frame camera images fed to the model as image history. They
        # must use the same names inference() reads, and must exist before the
        # base class starts anything that could call inference().
        self.last_front_img = None
        self.last_right_wrist_img = None
        self.last_left_wrist_img = None
        super().__init__(hz=hz, max_timestep=max_timestep, task_name=task_name, num_rollout=num_rollout)
        # Legacy attributes kept for backward compatibility; not read by this class.
        self.right_wrist_img = None
        self.left_wrist_img = None

    def bringup_model(self):
        """Load the RDT policy from ``self.checkpoint`` into ``self.model``.

        The model config is read from ``configs/base.yaml``, resolved relative
        to the current working directory, and the model is created in bfloat16.
        """
        with open('configs/base.yaml', "r", encoding="utf-8") as fp:
            config = yaml.safe_load(fp)
        self.model = create_model(
            args=config,
            dtype=torch.bfloat16,
            pretrained=self.checkpoint,
            pretrained_text_encoder_name_or_path="google/t5-v1_1-xxl",
            pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
            # Keep the model's control frequency in sync with the env's rate;
            # falls back to the previously hard-coded 20 Hz.
            control_frequency=getattr(self, '_control_hz', 20),
        )
        print('model loaded')

    def inference(self):
        """Send one action, re-querying the policy at the start of each chunk.

        When ``self.action_counter`` is a multiple of ``ACTION_CHUNK_SIZE``, the
        model is run with the end-effector pose, the previous three camera
        images and the current three, and the predicted chunk is cached in
        ``self.actions``. Every call then sends the action at the counter's
        position within the chunk, advances the counter, and records the
        current images as the history for the next query.
        """
        # Snapshot once so pose and images for this step come from the same obs
        # object even if a subscriber callback replaces self.obs meanwhile.
        obs = self.obs
        chunk = self.ACTION_CHUNK_SIZE
        if self.action_counter % chunk == 0:
            image_arrs = [
                self.last_front_img,
                self.last_right_wrist_img,
                self.last_left_wrist_img,
                obs['agentview_image'],
                obs['rightview_image'],
                obs['leftview_image'],
            ]
            # Missing history frames (first query) are passed through as None.
            images = [Image.fromarray(arr) if arr is not None else None for arr in image_arrs]
            proprio = torch.tensor(obs['eef_pose']).unsqueeze(0)
            with torch.inference_mode():
                prediction = self.model.step(
                    proprio=proprio,
                    images=images,
                    instruction=self.instruction
                )
            # .float() guards against bfloat16 outputs, which .numpy() rejects.
            self.actions = prediction.squeeze(0).float().cpu().numpy()
        act = self.actions[self.action_counter % chunk]
        self.action_counter += 1
        self.last_front_img = obs['agentview_image']
        self.last_right_wrist_img = obs['rightview_image']
        self.last_left_wrist_img = obs['leftview_image']
        self.send_action(act)
| |
|
if __name__ == '__main__':
    import traceback

    task_name = 'anubis_towel_kirby'
    rollout_num = 20
    hz = 20
    model_name = 'rdt'

    node = RDTInferenceRobotEnv(
        hz=hz,
        max_timestep=800,
        task_name=task_name,
        num_rollout=rollout_num,
        model_name=model_name
    )
    # One try/finally around the whole loop: Ctrl+C or an unexpected error
    # stops the loop (instead of spinning on a closed node), and ROS resources
    # are released exactly once, including after all rollouts complete.
    try:
        while node.rollout_counter < rollout_num:
            img = cv2.cvtColor(node.obs['agentview_image'], cv2.COLOR_RGB2BGR)
            if node.start:
                node.window.show(img, text=node.instruction)
            else:
                node.window.show(img, overlay_img=node.overlay_img, text=node.instruction)
            # NOTE(review): this refreshes the model's history frames on every
            # display iteration regardless of node.start, concurrently with
            # inference() updating them; confirm this is intended during a rollout.
            node.last_front_img = node.obs['agentview_image']
            node.last_right_wrist_img = node.obs['rightview_image']
            node.last_left_wrist_img = node.obs['leftview_image']
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()
    finally:
        node.ros_close()