# cccode / inference_avwm.py
# Uploaded by WayneW via huggingface_hub (commit 705a8fd, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from distributed import init_distributed
import torch
# Global side effect: allow TensorFloat-32 for CUDA matmuls and cuDNN
# convolutions, trading some float32 precision for speed on GPUs that support it.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import yaml
import argparse
import os
import numpy as np
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
import misc
import distributed as dist
from models import AVCDiT_models
# NOTE(review): presumably a project-local `datasets` module (it provides
# EvalDataset); it would shadow the Hugging Face `datasets` package if both exist.
from datasets import EvalDataset
from PIL import Image
from soundstream import SoundStream
import torchaudio
from skimage.measure import block_reduce
import matplotlib.pyplot as plt
import librosa
import time
import warnings
# Global side effect: silences every UserWarning process-wide, from any module.
warnings.filterwarnings("ignore", category=UserWarning)
from collections import defaultdict
import json
def save_image(output_file, img, unnormalize_img):
    """Save a channel-first image tensor to ``output_file`` as 8-bit RGB.

    Args:
        output_file: destination path; PIL picks the format from the extension.
        img: image tensor laid out as (C, H, W) with 3 channels. After optional
            un-normalization it is expected to lie roughly in [0, 1].
        unnormalize_img: if True, undo the dataset normalization with
            ``misc.unnormalize`` before quantizing.
    """
    # Work in float32 so bf16 inputs (produced under autocast) quantize accurately.
    img = img.detach().cpu().float()
    if unnormalize_img:
        img = misc.unnormalize(img)
    # Round and clamp before the uint8 cast. A bare ``.byte()`` truncates and does
    # not clamp, so values slightly outside [0, 1] could wrap around.
    img = (img * 255.0).round().clamp(0, 255).to(torch.uint8)
    image = Image.fromarray(img.permute(1, 2, 0).numpy(), mode='RGB')
    image.save(output_file)
def save_audio(output_file, audio_tensor, sample_rate):
    """Write ``audio_tensor`` to ``output_file`` with torchaudio as float32.

    A 1-D tensor is treated as mono and gets a leading channel axis. Other
    tensors are passed through unchanged; torchaudio.save expects
    (channels, frames) by default.
    """
    waveform = audio_tensor.detach().cpu()
    waveform = waveform.unsqueeze(0) if waveform.ndim == 1 else waveform
    torchaudio.save(output_file, waveform.float(), sample_rate)
def get_dataset_eval(config, dataset_name, eval_type, predefined_index=True):
    """Build the evaluation ``EvalDataset`` for ``dataset_name``.

    Args:
        config: merged eval config dict. It must provide ``eval_datasets``,
            ``image_size``, ``eval_distance``, ``eval_len_traj_pred``,
            ``traj_stride``, ``eval_context_size`` and ``normalize``.
        dataset_name: key into ``config["eval_datasets"]``.
        eval_type: evaluation mode. It is only used to locate the predefined
            index file ``data_splits/<dataset_name>/test/<eval_type>.pkl``.
        predefined_index: if truthy, pass that index file to the dataset;
            otherwise pass ``None``.

    Returns:
        The constructed ``EvalDataset``.
    """
    data_config = config["eval_datasets"][dataset_name]
    index_path = f"data_splits/{dataset_name}/test/{eval_type}.pkl" if predefined_index else None
    return EvalDataset(
        data_folder=data_config["data_folder"],
        data_split_folder=data_config["test"],
        dataset_name=dataset_name,
        image_size=config["image_size"],
        min_dist_cat=config["eval_distance"]["eval_min_dist_cat"],
        max_dist_cat=config["eval_distance"]["eval_max_dist_cat"],
        len_traj_pred=config["eval_len_traj_pred"],
        traj_stride=config["traj_stride"],
        context_size=config["eval_context_size"],
        normalize=config["normalize"],
        transform=misc.transform,
        goals_per_obs=4,
        predefined_index=index_path,
        traj_names='traj_names.txt',
    )
@torch.no_grad()
def model_forward_wrapper_v(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
    """Sample predicted video frames from the video-only diffusion model.

    Args:
        all_models: ``(model, diffusion, vae)`` tuple.
        curr_obs: observed frames indexed as (B, T, C, H, W). The first
            ``num_cond`` frames are VAE-encoded and used as conditioning.
        curr_delta: passed to the model as ``y`` after flattening its first two
            dims (presumably (B, num_goals, ...); confirm against the model).
        num_timesteps: multiplier applied to ``rel_t``.
        latent_size: spatial size of the 4-channel VAE latent to sample.
        device: device for all computation.
        num_cond: number of leading frames used as conditioning context.
        num_goals: number of samples drawn per batch element.
        rel_t: optional per-sample relative time. Defaults to 1/128 for each of
            the B samples. A caller-supplied tensor is not modified in place.
        progress: forwarded to ``diffusion.p_sample_loop``.

    Returns:
        The VAE-decoded samples (one per sampled latent, B * num_goals of them),
        clipped to [-1, 1].
    """
    model, diffusion, vae = all_models
    x = curr_obs.to(device)
    y = curr_delta.to(device)
    with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
        B, T = x.shape[:2]
        if rel_t is None:
            rel_t = (torch.ones(B) * (1. / 128.)).to(device)
        # Out-of-place multiply. ``rel_t *= ...`` would rescale a tensor owned by
        # the caller every time this function is invoked.
        rel_t = rel_t * num_timesteps
        # Encode every frame; 0.18215 is the latent scale, undone before decoding.
        x = vae.encode(x.flatten(0, 1)).latent_dist.sample().mul_(0.18215).unflatten(0, (B, T))
        # Repeat the conditioning context once per goal, then fold goals into the batch.
        x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, *x.shape[2:]).flatten(0, 1)
        z = torch.randn(B * num_goals, 4, latent_size, latent_size, device=device)
        y = y.flatten(0, 1)
        # NOTE(review): rel_t has B entries while x_cond has B * num_goals rows.
        # These presumably only line up for num_goals == 1 (the only value used in
        # this file); confirm what the model expects before using num_goals > 1.
        model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
        samples = diffusion.p_sample_loop(
            model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
        )
        samples = vae.decode(samples / 0.18215).sample
    return torch.clip(samples, -1., 1.)
@torch.no_grad()
def model_forward_wrapper_a(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
    """Sample predicted audio and a reward estimate from the audio-only model.

    Args:
        all_models: ``(model, diffusion, sstream)`` tuple. ``sstream`` is the
            SoundStream tokenizer; its ``encoder``, ``quantizer`` and ``decoder``
            are used here.
        curr_obs: observed audio indexed as (B, T, ...). The first ``num_cond``
            entries are encoded and used as conditioning.
        curr_delta: passed to the model as ``y`` after flattening its first two dims.
        num_timesteps: multiplier applied to ``rel_t``.
        latent_size: not used by this function; kept for signature parity with
            the video wrapper.
        device: device for all computation.
        num_cond: number of leading observations used as context.
        num_goals: number of samples drawn per batch element.
        rel_t: optional per-sample relative time. Defaults to 1/128 for each of
            the B samples. A caller-supplied tensor is not modified in place.
        progress: forwarded to ``diffusion.p_sample_loop``.

    Returns:
        ``(audio, reward)``:
        - ``audio`` is the SoundStream-decoded output.
        - ``reward`` is the last position of the sampled latent averaged over
          its 16 channels, shape (N, 1, 1) with N = B * num_goals. This assumes
          the sampler returns latents shaped like the noise it was given.
    """
    model, diffusion, sstream = all_models
    x = curr_obs.to(device)
    y = curr_delta.to(device)
    with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
        B, T = x.shape[:2]
        if rel_t is None:
            rel_t = (torch.ones(B) * (1. / 128.)).to(device)
        # Out-of-place multiply so a caller-supplied rel_t is left untouched.
        rel_t = rel_t * num_timesteps
        x = sstream.encoder(x.flatten(0, 1)).unflatten(0, (B, T))
        # Repeat the conditioning context once per goal, then fold goals into the batch.
        x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, *x.shape[2:]).flatten(0, 1)
        # NOTE(review): the audio latent shape (16 channels x 181 positions, the
        # last position being the reward token) is hard-coded and presumably must
        # match the tokenizer/model config.
        z = torch.randn(B * num_goals, 16, 181, device=device)
        y = y.flatten(0, 1)
        model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
        samples = diffusion.p_sample_loop(
            model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
        )
        # REWARD TOKEN: last position on the final axis -> (N, 16, 1); mean over dim 1 -> (N, 1, 1).
        patch_tok = samples[..., -1:]
        diff_pred = patch_tok.mean(dim=1, keepdim=True)
        samples = samples[..., :-1]
        # AUDIO TOKENS: quantize with channels moved last, then decode.
        quantized, _, _ = sstream.quantizer(samples.permute(0, 2, 1))
        samples = sstream.decoder(quantized.permute(0, 2, 1))
    return samples, diff_pred
@torch.no_grad()
def model_forward_wrapper_av(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
    """Jointly sample the next video frame, audio clip and reward estimate.

    Args:
        all_models: ``(model, diffusion, vae, sstream)``. ``diffusion`` is a dual
            sampler whose ``p_sample_loop`` takes video and audio shapes/noise
            and returns a ``(video_latents, audio_latents)`` pair.
        curr_obs: ``(frames, audio)`` pair, each indexed as (B, T, ...). The first
            ``num_cond`` entries of each are encoded and used as conditioning.
        curr_delta: passed to the model as ``y`` after flattening its first two dims.
        num_timesteps: multiplier applied to ``rel_t``. The callers in this file
            pass the prediction horizon (1 per rollout step, ``sec`` for the
            time evaluation).
        latent_size: spatial size of the 4-channel VAE latent to sample.
        device: device for all computation.
        num_cond: number of leading observations used as context.
        num_goals: number of samples drawn per batch element.
        rel_t: optional per-sample relative time. Defaults to 1/128 for each of
            the B samples. A caller-supplied tensor is not modified in place.
        progress: forwarded to ``diffusion.p_sample_loop``.

    Returns:
        ``(frames, audio, reward)``:
        - ``frames``: decoded frames clipped to [-1, 1].
        - ``audio``: the SoundStream-decoded audio.
        - ``reward``: the last position of the audio latent averaged over its
          16 channels, shape (N, 1, 1) with N = B * num_goals. This assumes the
          sampler returns latents shaped like the noise it was given.

    Raises:
        ValueError: if the frame and audio batches have different sizes.
    """
    model, diffusion, vae, sstream = all_models
    x_v, x_a = curr_obs
    x_v = x_v.to(device)
    x_a = x_a.to(device)
    y = curr_delta.to(device)
    with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
        B, T_v = x_v.shape[:2]
        B_a, T_a = x_a.shape[:2]
        # Previously B was silently overwritten by the audio batch size.
        if B_a != B:
            raise ValueError(f"video batch size ({B}) and audio batch size ({B_a}) differ")
        if rel_t is None:
            rel_t = (torch.ones(B) * (1. / 128.)).to(device)
        # Out-of-place multiply so a caller-supplied rel_t is left untouched.
        rel_t = rel_t * num_timesteps
        # 0.18215 is the VAE latent scale, undone before decoding below.
        x_v = vae.encode(x_v.flatten(0, 1)).latent_dist.sample().mul_(0.18215).unflatten(0, (B, T_v))
        x_a = sstream.encoder(x_a.flatten(0, 1)).unflatten(0, (B, T_a))
        # Repeat each conditioning context once per goal, then fold goals into the batch.
        x_v_cond = x_v[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, *x_v.shape[2:]).flatten(0, 1)
        x_a_cond = x_a[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, *x_a.shape[2:]).flatten(0, 1)
        z_v = torch.randn(B * num_goals, 4, latent_size, latent_size, device=device)
        # TODO / NOTE(review): the audio latent shape (16 x 181, the last position
        # being the reward token) is hard-coded and presumably must match the model.
        z_a = torch.randn(B * num_goals, 16, 181, device=device)
        y = y.flatten(0, 1)
        model_kwargs = dict(y=y, x_v_cond=x_v_cond, x_a_cond=x_a_cond, rel_t=rel_t)
        samples_v, samples_a = diffusion.p_sample_loop(
            model.forward, z_v.shape, z_a.shape, z_v, z_a, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
        )
        # Reward token: last position -> (N, 16, 1); mean over dim 1 -> (N, 1, 1).
        patch_tok = samples_a[..., -1:]
        diff_pred = patch_tok.mean(dim=1, keepdim=True)
        samples_a = samples_a[..., :-1]
        samples_v = vae.decode(samples_v / 0.18215).sample
        # Audio tokens: quantize with channels moved last, then decode.
        quantized, _, _ = sstream.quantizer(samples_a.permute(0, 2, 1))
        samples_a = sstream.decoder(quantized.permute(0, 2, 1))
    return torch.clip(samples_v, -1., 1.), samples_a, diff_pred
def generate_rollout(args, output_dir, rollout_frames, idxs, all_models, obs_av, gt_av, diffs_seq, delta, num_cond, device):
    """Autoregressively roll the audio-visual model forward one step at a time.

    At step ``k`` (1-based, up to ``rollout_frames``) the model predicts one
    step ahead from the current context windows. Each prediction is appended to
    its window and the oldest entry is dropped. With ``args.gt`` set, the
    ground-truth frames, audio and rewards for each step are saved instead and
    no model is run.

    Per-sample outputs go to ``output_dir/id_<idx>/``:
    - ``<k>.png`` and ``<k>.wav``
    - ``compare_<k>.png`` (prediction mode only)
    - ``distance.json``, listing the de-normalized reward value per step.
    """
    obs_image, obs_audio, _ = obs_av
    gt_image, _, orig_gt_audio = gt_av
    gt_image = gt_image[:, :rollout_frames]

    window_v = obs_image.to(device)
    window_a = obs_audio.to(device)
    # Model audio -> 16 kHz for saving and plotting.
    resample_to_16k = torchaudio.transforms.Resample(
        orig_freq=48000, new_freq=16000, lowpass_filter_width=64
    ).to(device, dtype=torch.bfloat16)
    records = defaultdict(list)
    value_key = "denorm_pred" if not args.gt else "denorm_gt"

    def _log_values(step_sec, values):
        # One {"sec", "value"} row per batch element, keyed by dataset index.
        for row, sample_idx in enumerate(idxs.detach().view(-1).cpu().numpy()):
            records[int(sample_idx)].append({"sec": int(step_sec), "value": float(values[row])})

    for step in range(gt_image.shape[1]):
        sec = step + 1
        step_delta = delta[:, step:sec].to(device)
        gt_frame = gt_image[:, step].to(device)
        gt_wave = orig_gt_audio[:, step].to(device)

        if args.gt:
            visualize_preds(output_dir, idxs, sec, gt_frame, gt_wave, 16000)
            _log_values(sec, denorm_from_tensor(diffs_seq[:, step:sec, :]))
            continue

        reward_gt = diffs_seq[:, step:sec, :].unsqueeze(1).to(device)
        pred_frame, pred_wave, reward_pred = model_forward_wrapper_av(
            all_models, (window_v, window_a), step_delta,
            num_timesteps=1, latent_size=args.latent_size, device=device,
            num_cond=num_cond, num_goals=1,
        )
        pred_wave_16k = resample_to_16k(pred_wave)
        # Slide both context windows: push the new prediction, drop the oldest entry.
        window_v = torch.cat((window_v, pred_frame.unsqueeze(1)), dim=1)[:, 1:]
        window_a = torch.cat((window_a, pred_wave.unsqueeze(1)), dim=1)[:, 1:]
        pred_vals = denorm_from_tensor(reward_pred)
        gt_vals = denorm_from_tensor(reward_gt)
        visualize_preds(output_dir, idxs, sec, pred_frame, pred_wave_16k, 16000)
        visualize_compare(output_dir, idxs, sec,
                          pred_frame, pred_wave_16k,
                          gt_frame, gt_wave,
                          denorm_pred_vals=pred_vals,
                          denorm_gt_vals=gt_vals)
        _log_values(sec, pred_vals)

    for sample_idx, entries in records.items():
        ordered = sorted(entries, key=lambda entry: entry["sec"])
        folder = os.path.join(output_dir, f"id_{sample_idx}")
        os.makedirs(folder, exist_ok=True)
        payload = [{"sec": entry["sec"], value_key: entry["value"]} for entry in ordered]
        with open(os.path.join(folder, "distance.json"), "w") as f:
            json.dump(payload, f, indent=2)
def generate_time(args, output_dir, idxs, all_models, obs_av, gt_av, diffs_seq, delta, secs, num_cond, device):
    """Predict each requested horizon directly from the observed context.

    For every ``sec`` in ``secs``, the model is conditioned on the summed deltas
    ``delta[:, :sec]`` and predicts step ``sec`` in one shot, without
    autoregression. With ``args.gt`` set, the ground-truth frame, audio and
    summed ground-truth reward for that step are saved instead.

    Per-sample outputs go to ``output_dir/id_<idx>/``:
    - ``<sec>.png`` and ``<sec>.wav``
    - ``compare_<sec>.png`` (prediction mode only)
    - ``distance.json``, listing the de-normalized reward value per ``sec``.

    Raises:
        ValueError: if any ``sec`` is outside ``1..gt_image.shape[1]``. A value
            of 0 previously indexed ``gt_image[:, -1]`` silently.
    """
    obs_image, obs_audio, _ = obs_av
    gt_image, _, orig_gt_audio = gt_av

    num_gt_steps = gt_image.shape[1]
    for sec in secs:
        if not 1 <= sec <= num_gt_steps:
            raise ValueError(f"requested sec={sec} is outside the available range 1..{num_gt_steps}")

    # Only the prediction path needs the 48 kHz -> 16 kHz resampler.
    down_resampler = None
    if not args.gt:
        down_resampler = torchaudio.transforms.Resample(orig_freq=48000, new_freq=16000, lowpass_filter_width=64).to(device, dtype=torch.bfloat16)

    idxs_1d = idxs.detach().reshape(-1).cpu().numpy()  # loop-invariant
    episode_records = defaultdict(list)  # {sample_idx: [{"sec": int, "value": float}, ...]}
    value_key = "denorm_gt" if args.gt else "denorm_pred"

    for sec in secs:
        x_gt_pixels = gt_image[:, sec - 1].to(device)
        x_gt_audios_orig = orig_gt_audio[:, sec - 1].to(device)
        if args.gt:
            values = denorm_from_tensor(diffs_seq[:, :sec, :].sum(dim=1, keepdim=True))  # [B]
            visualize_preds(output_dir, idxs, sec, x_gt_pixels, x_gt_audios_orig, 16000)
        else:
            curr_delta = delta[:, :sec].sum(dim=1, keepdim=True)
            diff_gt = diffs_seq[:, :sec, :].sum(dim=1, keepdim=True).to(device)
            x_pred_pixels, x_pred_audios, diff_pred = model_forward_wrapper_av(all_models, (obs_image, obs_audio), curr_delta, sec, args.latent_size, num_cond=num_cond, num_goals=1, device=device)
            x_pred_audios_orig = down_resampler(x_pred_audios)
            values = denorm_from_tensor(diff_pred)  # [B]
            denorm_gt_vals = denorm_from_tensor(diff_gt)  # [B]
            visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios_orig, 16000)
            visualize_compare(output_dir, idxs, sec,
                              x_pred_pixels, x_pred_audios_orig,
                              x_gt_pixels, x_gt_audios_orig,
                              denorm_pred_vals=values,
                              denorm_gt_vals=denorm_gt_vals)
        for b, sample_idx in enumerate(idxs_1d):
            episode_records[int(sample_idx)].append({"sec": int(sec), "value": float(values[b])})

    for sample_idx, rows in episode_records.items():
        rows = sorted(rows, key=lambda r: r["sec"])
        sample_folder = os.path.join(output_dir, f"id_{sample_idx}")
        os.makedirs(sample_folder, exist_ok=True)
        out_json = os.path.join(sample_folder, "distance.json")
        compact = [{"sec": r["sec"], value_key: r["value"]} for r in rows]
        with open(out_json, "w") as f:
            json.dump(compact, f, indent=2)
def visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios, sample_rate):
    """Write ``<sec>.png`` and ``<sec>.wav`` for each batch element.

    Files go to ``output_dir/id_<idx>/``, where ``idx`` comes from the flattened
    ``idxs``. Images are un-normalized before saving; audio is written at
    ``sample_rate``.
    """
    for row, raw_idx in enumerate(idxs.detach().view(-1)):
        folder = os.path.join(output_dir, f'id_{int(raw_idx.item())}')
        os.makedirs(folder, exist_ok=True)
        save_image(os.path.join(folder, f'{sec}.png'), x_pred_pixels[row], True)
        save_audio(os.path.join(folder, f'{sec}.wav'), x_pred_audios[row], sample_rate)
def _compute_binaural_spectrogram_np(audio_2ch: np.ndarray):
    """Return pooled log-magnitude STFTs of channels 0 and 1, stacked as (F, T, 2).

    Each channel goes through ``librosa.stft`` (n_fft=512, hop=160, win=400),
    then 4x4 mean pooling via ``block_reduce``, then ``log1p``.
    """
    def _pooled_log_magnitude(channel):
        magnitude = np.abs(librosa.stft(channel, n_fft=512, hop_length=160, win_length=400))
        pooled = block_reduce(magnitude, block_size=(4, 4), func=np.mean)
        return np.log1p(pooled)

    per_channel = [_pooled_log_magnitude(audio_2ch[ch]) for ch in (0, 1)]
    return np.stack(per_channel, axis=-1)
def denorm_from_tensor(t: torch.Tensor, min_v=-20.0, max_v=20.0, scale=0.15) -> torch.Tensor:
    """Map normalized values in [-1, 1] back to scaled raw units, one per batch row.

    Takes the first element of each batch row of ``t`` (all dims after dim 0 are
    flattened) and applies ``((x + 1) / 2 * (max_v - min_v) + min_v) * scale``.
    With the defaults this reduces to ``3 * x``.

    Args:
        t: tensor whose dim 0 is the batch, of any float dtype. Non-contiguous
            inputs (slices, permutations) are accepted. A 0-d tensor is treated
            as a batch of one.
        min_v: raw value that -1 maps to.
        max_v: raw value that +1 maps to.
        scale: multiplier applied after de-normalization.

    Returns:
        float32 tensor of shape (B,), detached from autograd, on ``t``'s device.
    """
    x = t.detach().float()
    if x.ndim == 0:
        x = x.reshape(1)
    elif x.ndim > 1:
        # flatten (not .view): .view raises on strided/permuted inputs and on
        # empty batches, both of which can reach this function.
        x = x.flatten(1)[:, 0]
    n01 = (x + 1.0) / 2.0
    raw = n01 * (max_v - min_v) + min_v
    return raw * scale
def _tensor_to_display_img(x: torch.Tensor):
    """Convert a normalized (C, H, W) image tensor to an (H, W, C) uint8 numpy array."""
    # float32 first so bf16 inputs round accurately to 8 bits.
    x = x.detach().cpu().float()
    x = misc.unnormalize(x)
    x = (x * 255.0).round().clamp(0, 255)
    return x.to(torch.uint8).permute(1, 2, 0).numpy()


def visualize_compare(output_dir, idxs, sec,
                      x_pred_pixels, x_pred_audios_orig,
                      x_gt_pixels, x_gt_audios_orig,
                      denorm_pred_vals,
                      denorm_gt_vals):
    """Save a prediction vs. ground-truth comparison figure for each batch element.

    Writes ``output_dir/id_<idx>/compare_<sec>.png``, a 2x4 grid:
    - predicted and ground-truth images in the top-left two cells;
    - left/right-channel log spectrograms of the predicted and ground-truth
      audio in the right two columns, with one colour scale per channel shared
      by pred and gt;
    - both de-normalized reward values in the title.

    Args:
        output_dir: root folder for the per-sample subfolders.
        idxs: dataset indices of the batch, flattened to one id per element.
        sec: step/horizon label used in the file name and title.
        x_pred_pixels, x_gt_pixels: normalized image batches, (B, C, H, W).
        x_pred_audios_orig, x_gt_audios_orig: audio batches. Element ``b`` must
            provide channels ``[0]`` (left) and ``[1]`` (right).
        denorm_pred_vals, denorm_gt_vals: per-element reward values, or None to
            display 0.

    Raises:
        ValueError: if the four batches do not share the same batch size.
    """
    idxs_np = idxs.detach().reshape(-1).cpu().numpy()
    B = x_pred_pixels.shape[0]
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not (x_gt_pixels.shape[0] == B and x_pred_audios_orig.shape[0] == B and x_gt_audios_orig.shape[0] == B):
        raise ValueError(
            f"batch size mismatch: pred_pixels={B}, gt_pixels={x_gt_pixels.shape[0]}, "
            f"pred_audio={x_pred_audios_orig.shape[0]}, gt_audio={x_gt_audios_orig.shape[0]}"
        )
    for b in range(B):
        sample_idx = int(idxs_np[b])
        sample_folder = os.path.join(output_dir, f'id_{sample_idx}')
        os.makedirs(sample_folder, exist_ok=True)
        out_path = os.path.join(sample_folder, f'compare_{sec}.png')

        pred_img = _tensor_to_display_img(x_pred_pixels[b])
        gt_img = _tensor_to_display_img(x_gt_pixels[b])
        pred_spec = _compute_binaural_spectrogram_np(x_pred_audios_orig[b].detach().cpu().float().numpy())
        gt_spec = _compute_binaural_spectrogram_np(x_gt_audios_orig[b].detach().cpu().float().numpy())

        # Shared colour limits per channel so pred and gt are directly comparable.
        vmin_L = min(pred_spec[:, :, 0].min(), gt_spec[:, :, 0].min())
        vmax_L = max(pred_spec[:, :, 0].max(), gt_spec[:, :, 0].max())
        vmin_R = min(pred_spec[:, :, 1].min(), gt_spec[:, :, 1].min())
        vmax_R = max(pred_spec[:, :, 1].max(), gt_spec[:, :, 1].max())
        dn_pred = float(denorm_pred_vals[b]) if denorm_pred_vals is not None else 0
        dn_gt = float(denorm_gt_vals[b]) if denorm_gt_vals is not None else 0

        fig, axes = plt.subplots(2, 4, figsize=(14, 6), constrained_layout=True)
        try:
            axes[0, 0].imshow(pred_img); axes[0, 0].set_title('pred image'); axes[0, 0].axis('off')
            axes[0, 1].imshow(gt_img); axes[0, 1].set_title('gt image'); axes[0, 1].axis('off')
            axes[1, 0].axis('off')
            axes[1, 1].axis('off')
            axes[0, 2].imshow(pred_spec[:, :, 0], origin='lower', aspect='auto', vmin=vmin_L, vmax=vmax_L)
            axes[0, 2].set_title('pred spec (Left)'); axes[0, 2].set_xticks([]); axes[0, 2].set_yticks([])
            axes[0, 3].imshow(gt_spec[:, :, 0], origin='lower', aspect='auto', vmin=vmin_L, vmax=vmax_L)
            axes[0, 3].set_title('gt spec (Left)'); axes[0, 3].set_xticks([]); axes[0, 3].set_yticks([])
            axes[1, 2].imshow(pred_spec[:, :, 1], origin='lower', aspect='auto', vmin=vmin_R, vmax=vmax_R)
            axes[1, 2].set_title('pred spec (Right)'); axes[1, 2].set_xticks([]); axes[1, 2].set_yticks([])
            axes[1, 3].imshow(gt_spec[:, :, 1], origin='lower', aspect='auto', vmin=vmin_R, vmax=vmax_R)
            axes[1, 3].set_title('gt spec (Right)'); axes[1, 3].set_xticks([]); axes[1, 3].set_yticks([])
            fig.suptitle(
                f'id={sample_idx}, sec={sec} | denorm(reward_pred)={dn_pred:.4f}, denorm(reward_gt)={dn_gt:.4f}',
                fontsize=11
            )
            # Save this specific figure rather than pyplot's "current" one.
            fig.savefig(out_path, dpi=180)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
@torch.no_grad()
def main(args):
    """Run audio-visual world-model evaluation (or dump ground truth).

    Steps:
    1. Load ``config/eval_config.yaml``, overridden key-by-key by the experiment
       YAML at ``args.exp``.
    2. Build the model stack unless ``args.gt`` is set.
    3. For every batch of every dataset in ``args.datasets``, run
       ``generate_rollout`` or ``generate_time``, writing under
       ``<output_dir>/<gt | exp_name[_ckp]>/<dataset>/<rollout_<N>frames | time>/``.

    Raises:
        ValueError: on an unknown ``args.eval_type``, on ``args.rollout_frames``
            outside ``[1, eval_len_traj_pred]`` (``-1`` means "all"), or on a
            non-integer entry in ``args.time_secs``.
    """
    # Validate the mode before any expensive setup. Previously an unknown value
    # loaded every model and iterated the whole dataset while writing nothing.
    if args.eval_type not in ('rollout', 'time'):
        raise ValueError(f"--eval_type must be 'rollout' or 'time', got {args.eval_type!r}")

    _, _, device, _ = init_distributed()
    print(args)
    device = torch.device(device)
    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()
    exp_eval = args.exp

    # Output root. Ground truth goes to a shared 'gt' folder; predictions are
    # namespaced by experiment file name, plus the checkpoint if it is non-default.
    if args.gt:
        args.save_output_dir = os.path.join(args.output_dir, 'gt')
    else:
        exp_name = os.path.basename(exp_eval).split('.')[0]
        args.save_output_dir = os.path.join(args.output_dir, exp_name)
        if args.ckp != '0100000':
            args.save_output_dir = args.save_output_dir + "_%s" % (args.ckp)
    os.makedirs(args.save_output_dir, exist_ok=True)

    with open("config/eval_config.yaml", "r") as f:
        config = yaml.safe_load(f)
    with open(exp_eval, "r") as f:
        user_config = yaml.safe_load(f)
    # Shallow merge; an empty experiment YAML (safe_load -> None) adds nothing.
    config.update(user_config or {})

    eval_len_traj_pred = config["eval_len_traj_pred"]
    if args.rollout_frames == -1:
        args.rollout_frames = eval_len_traj_pred
    # Explicit check (not ``assert``) so it still runs under ``python -O``.
    if not 1 <= args.rollout_frames <= eval_len_traj_pred:
        raise ValueError(
            f"--rollout_frames must be in [1, {eval_len_traj_pred}] (or -1 for all), got {args.rollout_frames}"
        )
    latent_size = config['image_size'] // 8  # side length of the VAE latent
    args.latent_size = latent_size
    num_cond = config['context_size']

    # Horizons for the 'time' evaluation are the same for every batch; parse once.
    secs = None
    if args.eval_type == 'time':
        if args.time_secs != '':
            secs = np.array([int(sec) for sec in args.time_secs.split(',')])
        else:
            secs = np.array([int(sec) for sec in range(1, args.rollout_frames + 1)])

    print("loading")
    model_lst = (None, None, None, None)
    if not args.gt:
        model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4, mode="av")
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        ckp = torch.load(f'{config["results_dir"]}/{config["run_name"]}/checkpoints/{args.ckp}.pth.tar', map_location='cpu', weights_only=False)
        print(model.load_state_dict(ckp["ema"], strict=True))
        model.eval()
        model.to(device)
        model = torch.compile(model)
        diffusion = create_diffusion(str(250), dual=True)
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device)
        sstream = SoundStream(C=32, D=16, n_q=8, codebook_size=1024).to(device)
        sstream_path = config["tokenizer_a_path"]
        sstream_checkpoint = torch.load(sstream_path, map_location=device)
        sstream.load_state_dict(sstream_checkpoint["model_state"])
        sstream.eval()
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], find_unused_parameters=False)
        model_lst = (model, diffusion, vae, sstream)

    # Loading Datasets
    dataset_names = args.datasets.split(',')
    datasets = {}
    for dataset_name in dataset_names:
        dataset_val = get_dataset_eval(config, dataset_name, args.eval_type, predefined_index=False)
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        datasets[dataset_name] = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False
        )

    print_freq = 1
    header = 'Evaluation: '
    metric_logger = dist.MetricLogger(delimiter=" ")
    for dataset_name in dataset_names:
        dataset_save_output_dir = os.path.join(args.save_output_dir, dataset_name)
        os.makedirs(dataset_save_output_dir, exist_ok=True)
        # The per-mode output folder does not depend on the batch; create it once.
        if args.eval_type == 'rollout':
            curr_output_dir = os.path.join(dataset_save_output_dir, f'rollout_{args.rollout_frames}frames')
        else:
            curr_output_dir = os.path.join(dataset_save_output_dir, 'time')
        os.makedirs(curr_output_dir, exist_ok=True)
        curr_data_loader = datasets[dataset_name]
        for (idxs, obs_image, gt_image, obs_audio, gt_audio, diffs_seq, delta,
             orig_obs_audio, orig_gt_audio) in metric_logger.log_every(curr_data_loader, print_freq, header):
            with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
                # Keep only the last num_cond observations as model context.
                obs_image = obs_image[:, -num_cond:].to(device)
                gt_image = gt_image.to(device)
                obs_audio = obs_audio[:, -num_cond:].to(device)
                gt_audio = gt_audio.to(device)
                orig_obs_audio = orig_obs_audio[:, -num_cond:].to(device)
                orig_gt_audio = orig_gt_audio.to(device)
                diffs_seq = diffs_seq.to(device)
                obs_av = (obs_image, obs_audio, orig_obs_audio)
                gt_av = (gt_image, gt_audio, orig_gt_audio)
                if args.eval_type == 'rollout':
                    generate_rollout(args, curr_output_dir, args.rollout_frames, idxs, model_lst, obs_av, gt_av, diffs_seq, delta, num_cond, device)
                else:  # 'time' (validated at the top of main)
                    generate_time(args, curr_output_dir, idxs, model_lst, obs_av, gt_av, diffs_seq, delta, secs, num_cond, device)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # main() joins/opens/splits these three, so a None default always crashed later.
    parser.add_argument("--output_dir", type=str, required=True, help="output directory")
    parser.add_argument("--exp", type=str, required=True, help="path to the experiment YAML config")
    parser.add_argument("--ckp", type=str, default='0100000', help="checkpoint name under <results_dir>/<run_name>/checkpoints/")
    parser.add_argument("--num_sec_eval", type=int, default=5, help="not read by main()")
    parser.add_argument("--input_fps", type=int, default=4, help="not read by main()")
    parser.add_argument("--datasets", type=str, required=True, help="comma-separated dataset names")
    parser.add_argument("--num_workers", type=int, default=8, help="num workers")
    parser.add_argument("--batch_size", type=int, default=16, help="batch size")
    parser.add_argument("--eval_type", type=str, default=None, choices=['time', 'rollout'],
                        help="type of evaluation has to be either 'time' or 'rollout'")
    # Rollout / time evaluation args
    parser.add_argument("--time_secs", type=str, default='',
                        help="comma-separated horizons for 'time' eval, e.g. '1,2,3,4' (default: 1..rollout_frames)")
    parser.add_argument("--rollout_frames", type=int, default=-1,
                        help="number of steps to evaluate; -1 uses eval_len_traj_pred from the config")
    parser.add_argument("--gt", type=int, default=0, help="set to 1 to produce ground truth evaluation set")
    args = parser.parse_args()
    main(args)