| from . import VecEnvWrapper |
| from baselines.bench.monitor import ResultsWriter |
| import numpy as np |
| import time |
| from collections import deque |
|
|
class VecMonitor(VecEnvWrapper):
    """Vectorized-environment wrapper that records per-episode statistics.

    For every sub-environment it accumulates the reward sum and step count
    of the episode in progress. When a sub-environment reports ``done``, an
    ``'episode'`` entry with keys ``'r'`` (return), ``'l'`` (length) and
    ``'t'`` (seconds since this wrapper was constructed) is attached to a
    copy of that environment's info, and the accumulators for that
    environment are zeroed.
    """

    def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()):
        """Wrap *venv* and set up optional result logging and buffers.

        Args:
            venv: the vectorized environment to wrap.
            filename: if truthy, a ``ResultsWriter`` is created with it and
                receives one row per finished episode; otherwise nothing is
                written.
            keep_buf: if non-zero, the most recent ``keep_buf`` episode
                returns and lengths are kept in ``epret_buf`` and
                ``eplen_buf``. When it is zero those attributes are not
                created at all.
            info_keywords: keys copied from a finished episode's info into
                its episode record (and passed to the writer as
                ``extra_keys``).
        """
        VecEnvWrapper.__init__(self, venv)

        # The accumulators are only allocated by reset(); until then they
        # stay None.
        self.eprets = None
        self.eplens = None
        self.epcount = 0
        self.tstart = time.time()

        self.results_writer = (
            ResultsWriter(filename, header={'t_start': self.tstart},
                          extra_keys=info_keywords)
            if filename
            else None
        )

        self.info_keywords = info_keywords
        self.keep_buf = keep_buf
        if keep_buf:
            self.epret_buf = deque(maxlen=keep_buf)
            self.eplen_buf = deque(maxlen=keep_buf)

    def reset(self):
        """Reset the wrapped environment and zero all accumulators.

        Returns:
            Whatever ``self.venv.reset()`` returns.
        """
        observations = self.venv.reset()
        # One slot per sub-environment; self.num_envs is presumably set by
        # VecEnvWrapper.__init__ -- confirm against the base class.
        self.eprets = np.zeros(self.num_envs, 'f')
        self.eplens = np.zeros(self.num_envs, 'i')
        return observations

    def step_wait(self):
        """Collect a step from the wrapped environment and update statistics.

        Returns:
            ``(obs, rews, dones, newinfos)`` where ``obs``, ``rews`` and
            ``dones`` are passed through unchanged and ``newinfos`` is a
            list built from the wrapped environment's infos in which the
            entry of every finished sub-environment is replaced by a copy
            carrying an added ``'episode'`` record.

        Raises:
            KeyError: if a key listed in ``info_keywords`` is missing from
                the info of a finished sub-environment.
        """
        obs, rews, dones, infos = self.venv.step_wait()
        self.eprets += rews
        self.eplens += 1

        newinfos = list(infos[:])
        for env_idx, done in enumerate(dones):
            if not done:
                continue

            # Work on a copy so the info object handed back by the wrapped
            # environment is not modified.
            info = infos[env_idx].copy()
            ep_return = self.eprets[env_idx]
            ep_length = self.eplens[env_idx]
            elapsed = round(time.time() - self.tstart, 6)

            epinfo = {'r': ep_return, 'l': ep_length, 't': elapsed}
            epinfo.update((key, info[key]) for key in self.info_keywords)
            info['episode'] = epinfo

            if self.keep_buf:
                self.epret_buf.append(ep_return)
                self.eplen_buf.append(ep_length)

            self.epcount += 1
            # Start a fresh tally for this sub-environment before the row
            # is written out.
            self.eprets[env_idx] = 0
            self.eplens[env_idx] = 0

            if self.results_writer:
                self.results_writer.write_row(epinfo)
            newinfos[env_idx] = info

        return obs, rews, dones, newinfos
|
|