File size: 4,899 Bytes
30747b3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | from rlbench.backend.task import Task
from rlbench.backend.scene import DemoError
from rlbench.observation_config import ObservationConfig
from pyrep import PyRep
from pyrep.robots.arms.panda import Panda
from pyrep.robots.end_effectors.panda_gripper import PandaGripper
from rlbench.backend.const import TTT_FILE
from rlbench.backend.scene import Scene
from rlbench.backend.utils import task_file_to_task_class
from rlbench.backend.task import TASKS_PATH
from rlbench.backend.robot import Robot
from rlbench.backend.robot import UnimanualRobot
import numpy as np
import os
import argparse
# Number of times a variation's whole batch of demos is re-run before the
# variation is declared a failure.
DEMO_ATTEMPTS: int = 5
# Largest value a task's 'variation_count' may return and still validate.
MAX_VARIATIONS: int = 100
class TaskValidationError(Exception):
    """Raised when a task fails one of the validator's checks."""
def _check_task_api(task: Task) -> int:
    """Validate the statically checkable parts of a task's interface.

    Checks ``variation_count``, ``base_rotation_bounds`` and
    ``boundary_root``. ``task_smoke`` calls this after ``scene.load(task)``.

    Returns:
        The number of variations reported by ``task.variation_count()``.

    Raises:
        TaskValidationError: if any of the checked methods returns an
            invalid value.
    """
    variation_count = task.variation_count()
    # Zero variations would mean nothing gets run below and the task would
    # "pass" vacuously, so reject it as the message already promises (> 0).
    if variation_count < 1:
        raise TaskValidationError(
            "The method 'variation_count' should return a number > 0.")
    if variation_count > MAX_VARIATIONS:
        raise TaskValidationError(
            "This task had %d variations. Currently the limit is set to %d" %
            (variation_count, MAX_VARIATIONS))

    # Base rotation bounds: two 3-element sequences (position, orientation).
    base_pos, base_ori = task.base_rotation_bounds()
    if len(base_pos) != 3 or len(base_ori) != 3:
        raise TaskValidationError(
            "The method 'base_rotation_bounds' should return a tuple "
            "containing a list of floats.")

    # Boundary root must refer to an object that exists in the loaded scene.
    root = task.boundary_root()
    if not root.still_exists():
        raise TaskValidationError(
            "The method 'boundary_root' should return a Dummy that is the root "
            "of the task.")
    return variation_count


def task_smoke(task: Task, scene: Scene, variation=-1, demos=4, success=0.50,
               max_variations=3, test_demos=True):
    """Smoke-test a task: check its interface, then run demos per variation.

    Args:
        task: The task instance to validate.
        scene: Scene the task is loaded into and episodes are run in.
        variation: Index of the single variation to test, or a negative
            value to test the first ``max_variations`` variations.
        demos: Number of episodes run per attempt; must be >= 1.
        success: Minimum fraction of the ``demos`` episodes that must succeed
            for an attempt to pass.
        max_variations: Upper bound on the number of variations tested when
            ``variation`` is negative.
        test_demos: If True, each episode also records a demo via
            ``scene.get_demo(record=True)``; otherwise episodes are only
            initialised.

    Raises:
        ValueError: if ``demos`` is less than 1.
        TaskValidationError: if the task's interface is invalid, ``variation``
            is out of range, or too many demos fail on every attempt.
    """
    if demos < 1:
        # The pass/fail ratio below divides by ``demos``.
        raise ValueError('demos must be at least 1, got %r' % (demos,))

    print('Running task validator on task: %s' % task.get_name())

    scene.load(task)
    variation_count = _check_task_api(task)

    if variation < 0:
        variations_to_test = list(range(min(variation_count, max_variations)))
    elif variation < variation_count:
        variations_to_test = [variation]
    else:
        # Otherwise every episode of a non-existent variation fails and the
        # user only sees a misleading "too many failed demos" error.
        raise TaskValidationError(
            "Requested variation %d, but the task only has %d variation(s)." %
            (variation, variation_count))

    def run_demos(variation_num):
        """Run ``demos`` episodes of one variation; return how many failed."""
        fails = 0
        for _ in range(demos):
            try:
                scene.reset()
                desc = scene.init_episode(variation_num, max_attempts=10)
                if not isinstance(desc, list) or len(desc) <= 0:
                    raise TaskValidationError(
                        "The method 'init_variation' should return a list of "
                        "string descriptions.")
                if test_demos:
                    demo = scene.get_demo(record=True)
                    # Explicit check (not ``assert``) so it also runs under
                    # ``python -O`` and reports a readable message.
                    if len(demo) == 0:
                        raise RuntimeError(
                            'Recorded demo for variation %d is empty.'
                            % variation_num)
            except TaskValidationError:
                # A wrong return type from the task's own method is an
                # interface error, not a flaky demo; retrying is not expected
                # to fix it, so surface it instead of counting a failure.
                raise
            except DemoError as e:
                fails += 1
                print(e)
            except Exception as e:
                fails += 1
                # repr() so exceptions with an empty message still show type.
                print(repr(e))
        return fails

    def variation_smoke(i):
        """Validate one variation, retrying up to DEMO_ATTEMPTS times."""
        print('Running task validator on variation: %d' % i)
        failed_demos = 0
        passed = False
        for attempt in range(DEMO_ATTEMPTS):
            failed_demos = run_demos(i)
            passed = failed_demos / demos <= 1. - success
            if passed:
                break
            if attempt + 1 < DEMO_ATTEMPTS:
                print('Failed on attempt %d of %d. Trying again...' % (
                    attempt + 1, DEMO_ATTEMPTS))

        # Make sure we don't fail too often.
        if not passed:
            raise TaskValidationError(
                "Too many failed demo runs. %d of %d demos failed." % (
                    failed_demos, demos))
        print('Variation %d of task %s is good!' % (i, task.get_name()))
        if test_demos:
            print('%d of %d demos were successful.' % (
                demos - failed_demos, demos))

    # Task set-up
    scene.init_task()
    for i in variations_to_test:
        variation_smoke(i)
if __name__ == '__main__':
    # Command-line entry point: validate one task file in a headless sim.
    parser = argparse.ArgumentParser(
        description='Smoke-test an RLBench task file.')
    parser.add_argument("task", help="The task file to test.")
    parser.add_argument(
        "--variation", type=int, default=2,
        help="Variation index to test, or a negative value to test the first "
             "few variations (default: 2).")
    args = parser.parse_args()

    python_file = os.path.join(TASKS_PATH, args.task)
    if not os.path.isfile(python_file):
        raise RuntimeError('Could not find the task file: %s' % python_file)
    task_class = task_file_to_task_class(args.task)

    DIR_PATH = os.path.dirname(os.path.abspath(__file__))
    ttt_file = os.path.join(DIR_PATH, '..', 'rlbench', TTT_FILE)

    sim = PyRep()
    sim.launch(ttt_file, headless=True)
    # Shut the simulator down whatever goes wrong below. Previously only
    # TaskValidationError triggered shutdown; any other exception skipped it
    # and could leave the simulator running.
    try:
        sim.step_ui()
        sim.set_simulation_timestep(0.005)
        sim.step_ui()
        sim.start()

        robot = UnimanualRobot(Panda(), PandaGripper())
        active_task = task_class(sim, robot)
        obs = ObservationConfig()
        obs.set_all(False)
        scene = Scene(sim, robot, obs)
        task_smoke(active_task, scene, variation=args.variation)
    finally:
        sim.shutdown()
    print('Validation successful!')
|