| | import dm_env |
| | from absl import logging |
| |
|
| | import rclpy |
| | from sensor_msgs.msg import Image, JointState |
| | from std_msgs.msg import Bool |
| | from std_msgs.msg import Int32 |
| | import numpy as np |
| | import threading |
| | import time |
| | |
| | import random |
| | from scipy.spatial.transform import Rotation |
| | from glob import glob |
| | import os |
| | import h5py |
| | import cv2 |
| |
|
| | class AnubisRobotEnv: |
| | def __init__(self, hz=20, max_timestep=1000, task_name='', num_rollout=1): |
| | rclpy.init() |
| | self._node = rclpy.create_node('anubis_robot_env_node') |
| | self._subscriber_bringup() |
| | print('ROS2 node created') |
| |
|
| | self.window = None |
| | self.start = False |
| | self.thread_done = False |
| | self.hz = hz |
| | self.action_counter = 0 |
| | self.num_rollout = num_rollout |
| | self.rollout_counter = 0 |
| |
|
| | self.lang_dict = { |
| | 'anubis_brush_to_pan' : 'insert the brush to the dustpan', |
| | 'anubis_carrot_to_bag' : 'pick up the carrot and put into the bag', |
| | 'anubis_towel_kirby' : 'take the towel off the kirby doll' |
| | } |
| | self.task_name = task_name |
| | self.instruction = self.lang_dict[self.task_name] |
| | self.data_list = glob(f'/home/rllab/workspace/jellyho/demo_collection/{self.task_name}/*.hdf5') |
| |
|
| | self.overlay_img = None |
| | self.max_timestep = max_timestep |
| |
|
| | self.init_action = JointState() |
| | self.init_action.position = [ |
| | 0.20620185010895048, |
| | 0.16183641523267392, |
| | 0.2277105000367078, |
| | -0.42093861525667453, |
| | 0.6546518510233503, |
| | -0.5770953981378887, |
| | 0.24739146627474096, |
| | -1.6, |
| | 0.21136149716403216, |
| | -0.16027684481842075, |
| | 0.21879985782478842, |
| | 0.6606782591766969, |
| | -0.428768621033297, |
| | 0.2340722378552696, |
| | -0.569975345900049, |
| | -1.6 |
| | ] |
| |
|
| | print('Initializing Anubis Robot Environment') |
| |
|
| | self.thread = PeriodicThread(1/self.hz, self.timer_callback) |
| | self.thread.start() |
| |
|
| | self.video_thread = PeriodicThread(1/30, self.video_timer_callback) |
| | self.video_thread.start() |
| |
|
| | self.timer_thread = threading.Thread(target=rclpy.spin, args=(self._node,), daemon=True) |
| | self.timer_thread.start() |
| | print('Threads started') |
| |
|
| | self.bringup_model() |
| | self.initialize() |
| | logging.set_verbosity(logging.INFO) |
| | logging.info('AnubisRobotEnv successfully initialized.') |
| |
|
| | def init_robot_pose(self, demo): |
| | print('Initializing robot pose', demo % len(self.data_list)) |
| | root = h5py.File(self.data_list[demo % len(self.data_list)], 'r') |
| | first_action = root['action']['eef_pose'][0] |
| | self.publish_action(first_action) |
| | |
| | def initialize(self): |
| | self.curr_timestep = 0 |
| | if self.window is None: |
| | from visualize_utils import window |
| | self.window = window('ENV Observation', video_path=f'{self.model_name}-{self.task_name}', video_fps=30, video_size=(640, 480), show=False) |
| | else: |
| | self.window.init_video() |
| | self.send_demo(self.rollout_counter) |
| | self.init_robot_pose(self.rollout_counter) |
| |
|
| | def reset(self): |
| | while not self.thread_done: |
| | time.sleep(0.01) |
| | continue |
| | self.thread_done = False |
| | return dm_env.restart(observation=self._observation()) |
| |
|
| | def bringup_model(self): |
| | raise NotImplementedError |
| | |
| | def inference(self): |
| | raise NotImplementedError |
| |
|
| | def ros_close(self): |
| | self.thread.stop() |
| | self.timer_thread.stop() |
| | self._node.destroy_node() |
| | rclpy.shutdown() |
| |
|
| | def _subscriber_bringup(self): |
| | ''' |
| | Note: This function creates all the subscribers \ |
| | for reading joint and gripper states. |
| | ''' |
| | |
| | self.obs = {} |
| | self.action = {} |
| |
|
| | |
| | |
| | self._node.create_subscription(Image, '/camera_center/camera/color/image_raw', self.agentview_image_callback, 10) |
| | self.obs['agentview_image'] = np.zeros(shape=(480, 640, 3), dtype=np.uint8) |
| |
|
| | self._node.create_subscription(Image, '/camera_right/camera/color/image_raw', self.rightview_image_callback, 10) |
| | self.obs['rightview_image'] = np.zeros(shape=(480, 640, 3), dtype=np.uint8) |
| |
|
| | self._node.create_subscription(Image, '/camera_left/camera/color/image_raw', self.leftview_image_callback, 10) |
| | self.obs['leftview_image'] = np.zeros(shape=(480, 640, 3), dtype=np.uint8) |
| |
|
| | |
| | self._node.create_subscription(JointState, '/eef_pose', self.eef_pose_callback, 10) |
| | self.obs['eef_pose'] = np.zeros(shape=(20,), dtype=np.float64) |
| |
|
| | |
| | self.obs['language_instruction'] = '' |
| |
|
| | |
| | self._node.create_subscription(Bool, '/done', self.done_callback, 10) |
| |
|
| | self.demo_pub = self._node.create_publisher(Int32, '/demo', 10) |
| | self.action_pub = self._node.create_publisher(JointState, '/teleop/eef_pose', 10) |
| |
|
| | def send_demo(self, num): |
| | demo_msg = Int32() |
| | demo_msg.data = num |
| | self.demo_pub.publish(demo_msg) |
| |
|
| | |
| | def agentview_image_callback(self, msg): |
| | self.obs['agentview_image'] = np.reshape(msg.data, (480, 640, 3)) |
| |
|
| | def rightview_image_callback(self, msg): |
| | rightview = np.reshape(msg.data, (480, 640, 3)) |
| | self.obs['rightview_image'] = np.rot90(rightview, 2) |
| |
|
| | def leftview_image_callback(self, msg): |
| | self.obs['leftview_image'] = np.reshape(msg.data, (480, 640, 3)) |
| |
|
| | def eef_pose_callback(self, msg): |
| | recevied_data = np.array(msg.position) |
| | eef_pose_data = np.zeros(shape=(20,), dtype=np.float64) |
| | eef_pose_data[:3] = recevied_data[:3] |
| | eef_pose_data[3:9] = self.quat_to_6d(recevied_data[3:7], scalar_first=False) |
| | eef_pose_data[9] = recevied_data[7] |
| | eef_pose_data[10:13] = recevied_data[8:11] |
| | eef_pose_data[13:19] = self.quat_to_6d(recevied_data[11:15], scalar_first=False) |
| | eef_pose_data[19] = recevied_data[15] |
| | self.obs['eef_pose'] = eef_pose_data |
| |
|
| | def send_action(self, act): |
| | if self.start: |
| | action_msg = JointState() |
| | |
| | |
| | |
| | |
| | |
| | action_msg_data = np.zeros(16) |
| | action_msg_data[0:3] = act[0:3] |
| | action_msg_data[3:7] = self.sixd_to_quat(act[3:9]) |
| | action_msg_data[7] = act[9] |
| | action_msg_data[8:11] = act[10:13] |
| | action_msg_data[11:15] = self.sixd_to_quat(act[13:19]) |
| | action_msg_data[15] = act[19] |
| | action_msg.position = action_msg_data.astype(float).tolist() |
| | self.action_pub.publish(action_msg) |
| |
|
| | def publish_action(self, action): |
| | action_msg = JointState() |
| | |
| |
|
| | |
| | action = action.squeeze() |
| | action_msg_data = np.zeros(16) |
| | action_msg_data[0:3] = action[0:3] |
| | action_msg_data[3:7] = self.sixd_to_quat(action[3:9]) |
| | action_msg_data[7] = action[9] |
| | action_msg_data[8:11] = action[10:13] |
| | action_msg_data[11:15] = self.sixd_to_quat(action[13:19]) |
| | action_msg_data[15] = action[19] |
| | action_msg.position = action_msg_data.astype(float).tolist() |
| | self.action_pub.publish(action_msg) |
| |
|
| | def done_callback(self, msg): |
| | if not self.start: |
| | print('Inference & Video Recording Start') |
| | self.start = True |
| | self.window.video_start() |
| | else: |
| | self.start = False |
| | self.action_counter = 0 |
| | self.rollout_counter += 1 |
| | if self.window.video_recording: |
| | self.window.video_stop() |
| | self.initialize() |
| | print('Next Inference Ready') |
| |
|
| | def timer_callback(self): |
| | if self.start: |
| | self.inference() |
| | self.curr_timestep += 1 |
| | if self.curr_timestep >= self.max_timestep: |
| | print("Max timestep reached, resetting environment.") |
| | self.start = False |
| | if self.window.video_recording: |
| | self.window.video_stop() |
| | self.rollout_counter += 1 |
| | self.action_counter = 0 |
| | self.curr_timestep = 0 |
| | self.initialize() |
| | self.thread_done = True |
| |
|
| | def video_timer_callback(self): |
| | if self.start and self.window.video_recording: |
| | self.window.video_write() |
| |
|
| | def quat_to_6d(self, quat, scalar_first=False): |
| | r = Rotation.from_quat(quat, scalar_first=scalar_first) |
| | mat = r.as_matrix() |
| | return mat[:, :2].flatten() |
| | |
| | def sixd_to_quat(self, sixd, scalar_first=False): |
| | mat = np.zeros((3, 3)) |
| | mat[:, :2] = sixd.reshape(3, 2) |
| | mat[:, 2] = np.cross(mat[:, 0], mat[:, 1]) |
| | r = Rotation.from_matrix(mat) |
| | return r.as_quat(scalar_first=scalar_first) |
| | |
| | def ros_close(self): |
| | if self.window.video_recording: |
| | self.window.video_stop() |
| | self.thread.stop() |
| | self.video_thread.stop() |
| | self.timer_thread.stop() |
| | self._node.destroy_node() |
| | rclpy.shutdown() |
| |
|
class PeriodicThread(threading.Thread):
    """Thread that calls ``function(*args, **kwargs)`` every ``interval`` seconds.

    The period is measured from the start of one call to the start of the
    next. If a call overruns the interval, the next call starts immediately.
    Call :meth:`stop` to end the loop. The thread is not marked daemon, so it
    presumably must be stopped for the interpreter to exit.
    """

    def __init__(self, interval, function, *args, **kwargs):
        super().__init__()
        self.interval = interval
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self.stop_event = threading.Event()
        # Guards `interval` against concurrent change_period() calls.
        self._lock = threading.Lock()

    def run(self):
        while not self.stop_event.is_set():
            # monotonic() is immune to wall-clock jumps (NTP, manual changes).
            start_time = time.monotonic()
            self.function(*self.args, **self.kwargs)
            elapsed_time = time.monotonic() - start_time
            with self._lock:
                interval = self.interval
            # Waiting on the event instead of calling time.sleep() lets stop()
            # take effect immediately rather than after the remaining period.
            self.stop_event.wait(max(0.0, interval - elapsed_time))

    def stop(self):
        """Ask the loop to exit. A call already in progress still completes."""
        self.stop_event.set()

    def change_period(self, new_interval):
        """Set the period, in seconds, used from the next cycle onward."""
        with self._lock:
            self.interval = new_interval