File size: 1,874 Bytes
912c7e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import subprocess

import hydra
import wandb
from omegaconf import DictConfig, OmegaConf, open_dict

from pfp import set_seeds, REPO_DIRS
from pfp.common.visualization import RerunViewer as RV
from pfp.envs.rlbench_runner import RLBenchRunner
from pfp.policy.base_policy import BasePolicy


@hydra.main(version_base=None, config_path="../conf", config_name="eval")
def main(cfg: DictConfig):
    """Evaluate a trained policy checkpoint with the RLBench runner.

    The checkpoint named by ``cfg.policy.ckpt_name`` is fetched with rsync if it
    is not already present under ``REPO_DIRS.CKPT``. The model definition and
    the task / observation / point-cloud settings are then copied from the
    checkpoint's training ``config.yaml``, so evaluation uses the same setup
    the model was trained with. The policy is instantiated through
    ``<model _target_>.load_from_checkpoint`` and handed to ``RLBenchRunner``.
    Results are logged to Weights & Biases when ``cfg.log_wandb`` is true
    (otherwise wandb runs in "disabled" mode).

    Args:
        cfg: Hydra-composed config (``config_name="eval"`` in ``../conf``).

    Raises:
        subprocess.CalledProcessError: if the rsync checkpoint download fails.
    """
    # Expose Python's builtin ``eval`` as an ``${eval:...}`` resolver.
    # NOTE: this executes arbitrary expressions from the config, so only run
    # this with trusted config files.
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", eval)
    OmegaConf.resolve(cfg)
    print(OmegaConf.to_yaml(cfg))
    set_seeds(cfg.seed)

    # Download the checkpoint if it is not present locally. The remote path is
    # the local absolute path, i.e. the layout is assumed identical on rlgpu2.
    ckpt_path = REPO_DIRS.CKPT / cfg.policy.ckpt_name
    if not ckpt_path.exists():
        # check=True: stop here with rsync's exit status instead of failing
        # later with a confusing FileNotFoundError on config.yaml.
        subprocess.run(
            [
                "rsync",
                "-hPrl",  # human-readable, --partial --progress, recursive, keep symlinks
                f"chisari@rlgpu2:{ckpt_path}",
                f"{REPO_DIRS.CKPT}/",
            ],
            check=True,
        )

    # Mirror the training run's model and data settings into the eval config
    # so the environment produces the observations the model was trained on.
    train_cfg = OmegaConf.load(ckpt_path / "config.yaml")
    with open_dict(cfg):  # lift struct mode so new keys (e.g. `model`) can be added
        cfg.model = train_cfg.model
        cfg.env_runner.env_config.task_name = train_cfg.task_name
        cfg.env_runner.env_config.obs_mode = train_cfg.obs_mode
        cfg.env_runner.env_config.use_pc_color = train_cfg.dataset.use_pc_color
        cfg.env_runner.env_config.n_points = train_cfg.dataset.n_points
        # Build the policy via the model class's checkpoint loader.
        cfg.policy._target_ = train_cfg.model._target_ + ".load_from_checkpoint"

    print(OmegaConf.to_yaml(cfg))

    if cfg.env_runner.env_config.vis:
        RV("pfp_evaluate")
    wandb.init(
        project="pfp-eval-rebuttal",
        entity="rl-lab-chisari",
        config=OmegaConf.to_container(cfg),
        mode="online" if cfg.log_wandb else "disabled",
    )
    policy: BasePolicy = hydra.utils.instantiate(cfg.policy)
    env_runner = RLBenchRunner(**cfg.env_runner)
    env_runner.run(policy)
    wandb.finish()


if __name__ == "__main__":
    main()