| import pytest |
| from baselines.common.tests.envs.fixed_sequence_env import FixedSequenceEnv |
|
|
| from baselines.common.tests.util import simple_test |
| from baselines.run import get_learn_function |
| from baselines.common.tests import mark_slow |
|
|
|
|
# Keyword arguments applied to every algorithm's learn function in this module
# (merged on top of the per-algorithm entries of ``learn_kwargs``).
common_kwargs = {
    'seed': 0,
    'total_timesteps': 50000,
}
|
|
# Per-algorithm keyword arguments for the learn function. The keys of this
# mapping also define which algorithms the test is parametrized over.
learn_kwargs = dict(
    a2c={},
    ppo2={'nsteps': 10, 'ent_coef': 0.0, 'nminibatches': 1},
)
|
|
|
|
# Algorithms to parametrize over: a concrete, ordered snapshot of the keys of
# ``learn_kwargs`` (rather than a live ``dict_keys`` view), so the parameter
# list is fixed at import time and can be indexed or reused freely.
alg_list = list(learn_kwargs)
# Recurrent network types passed as ``network=`` to each learn function.
rnn_list = ['lstm']
|
|
@mark_slow
@pytest.mark.parametrize("alg", alg_list)
@pytest.mark.parametrize("rnn", rnn_list)
def test_fixed_sequence(alg, rnn):
    '''
    Test that algorithm ``alg`` with recurrent network ``rnn`` can solve
    ``FixedSequenceEnv`` (10 actions, episodes of length 5).

    Success is judged by ``simple_test`` with the value 0.7 (presumably the
    minimum required reward; see baselines.common.tests.util to confirm).
    Presumably the agent must learn to emit a fixed action sequence, which
    requires memory across steps; this is why only recurrent networks are
    exercised here.
    '''
    # Copy before merging: learn_kwargs[alg] is a shared module-level dict, so
    # updating it in place would leak common_kwargs into every later use of
    # that entry within the same process.
    kwargs = dict(learn_kwargs[alg])
    kwargs.update(common_kwargs)

    def env_fn():
        return FixedSequenceEnv(n_actions=10, episode_len=5)

    def learn(e):
        # The learn function is resolved lazily, only when simple_test calls us.
        return get_learn_function(alg)(env=e, network=rnn, **kwargs)

    simple_test(env_fn, learn, 0.7)
|
|
|
|
if __name__ == '__main__':
    # Run a single configuration directly, bypassing pytest's parametrization.
    test_fixed_sequence(alg='ppo2', rnn='lstm')
|
|
|
|
|
|
|
|