| import pytest |
# Probe for the optional mujoco_py dependency once, at import time, so the
# tests in this module can be skipped instead of erroring when it is unusable.
try:
    import mujoco_py
except BaseException:
    # NOTE(review): BaseException is broader than Exception; presumably
    # deliberate so that any import-time failure (missing library, missing
    # key, bad LD_LIBRARY_PATH) results in a skip -- confirm before narrowing.
    mujoco_py = None
    _mujoco_present = False
else:
    _mujoco_present = True
|
|
|
|
@pytest.mark.skipif(
    not _mujoco_present,
    reason='error loading mujoco - either mujoco / mujoco key not present, or LD_LIBRARY_PATH is not pointing to mujoco library'
)
def test_lstm_example():
    """Step an LSTM policy through one Reacher-v2 episode.

    Builds a 128-unit LSTM policy for a single-environment ``DummyVecEnv``,
    initialises its variables, and steps the policy and the environment
    together until the environment reports ``done``. The test passes when
    more than five steps were taken before that happened.

    The environment is always closed on exit, including when the rollout
    raises, so a failure here does not leak the simulator into later tests.
    """
    # Heavy / optional dependencies are imported lazily so that collecting
    # this module does not require them when the test is skipped.
    import tensorflow as tf
    from baselines.common import policies, models, cmd_util
    from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

    # One seeded Reacher-v2 environment behind the vectorised interface.
    venv = DummyVecEnv([lambda: cmd_util.make_mujoco_env('Reacher-v2', seed=0)])

    try:
        with tf.compat.v1.Session() as sess:
            # Build the policy for a batch of one environment, one step at a time.
            policy = policies.build_policy(venv, models.lstm(128))(nbatch=1, nsteps=1)

            # Variables must be initialised before the policy can be stepped.
            sess.run(tf.compat.v1.global_variables_initializer())

            # Rollout bookkeeping: current observation, recurrent state, and
            # the done flags from the previous step (none yet).
            ob = venv.reset()
            state = policy.initial_state
            done = [False]
            step_counter = 0

            # Run until the environment signals the end of the episode.
            while True:
                # The recurrent state is threaded through every call; M gets
                # the previous step's done flags (presumably the mask used to
                # reset the LSTM state at episode boundaries -- confirm).
                action, _, state, _ = policy.step(ob, S=state, M=done)
                ob, _, done, _ = venv.step(action)
                step_counter += 1
                # Truth-testing `done` relies on there being exactly one
                # environment in the vectorised wrapper.
                if done:
                    break
    finally:
        # Release the environment whether or not the rollout succeeded.
        venv.close()

    # The episode must not terminate almost immediately.
    assert step_counter > 5
|
|
|
|
|
|
|
|
|
|