File size: 1,145 Bytes
a32fcea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import unittest
from rlbench.tasks import (FS10_V1, FS25_V1, FS50_V1, FS95_V1, MT15_V1,
                           MT30_V1, MT55_V1, MT100_V1)


# Expectations for the FS*_V1 task sets, one tuple per set:
# (task set, expected number of train tasks, expected number of test tasks).
FS_V1 = [
    (FS10_V1, 10, 5), (FS25_V1, 25, 5),
    (FS50_V1, 50, 5), (FS95_V1, 95, 5),
]

# Expectations for the MT*_V1 task sets, one tuple per set:
# (task set, expected number of train tasks).
MT_V1 = [
    (MT15_V1, 15), (MT30_V1, 30),
    (MT55_V1, 55), (MT100_V1, 100),
]


class TestTaskSet(unittest.TestCase):
    """Checks the split sizes and task uniqueness of the V1 task sets."""

    def test_fs_v1(self):
        """Every FS_V1 entry has the expected split sizes and no repeats."""
        for task_set, n_train, n_test in FS_V1:
            with self.subTest(task_set='FS%d_V1' % n_train):
                train_tasks = task_set['train']
                self.assertEqual(len(train_tasks), n_train)
                test_tasks = task_set['test']
                self.assertEqual(len(test_tasks), n_test)

                # A task must not occur twice anywhere across both splits.
                combined = train_tasks + test_tasks
                unique = set(train_tasks + test_tasks)
                self.assertEqual(len(combined), len(unique))

                # No train task may also be listed in the test split.
                shared = any(task in test_tasks for task in train_tasks)
                self.assertFalse(shared)

    def test_mt_v1(self):
        """Every MT_V1 entry has the expected train size and no repeats."""
        for task_set, n_train in MT_V1:
            with self.subTest(task_set='MT%d_V1' % n_train):
                train_tasks = task_set['train']
                self.assertEqual(len(train_tasks), n_train)

                # A task must not occur twice within the train split.
                unique = set(train_tasks)
                self.assertEqual(len(train_tasks), len(unique))