| |
|
|
| import argparse |
| from pathlib import Path |
| from typing import Optional, Tuple |
|
|
| from omegaconf import OmegaConf, DictConfig |
|
|
| from .. import logger |
| from ..data import KittiDataModule |
| from .run import evaluate |
|
|
|
|
# Base config used by `run` when `sequential` is False: an empty config, so
# only the caller-supplied overrides apply on top of it.
default_cfg_single = OmegaConf.create(dict())

# Base config used by `run` when `sequential` is True.
default_cfg_sequential = OmegaConf.create(
    dict(
        data=dict(
            # Reuse the data module's default maximum initial error as the
            # mask radius.
            mask_radius=KittiDataModule.default_cfg["max_init_error"],
            # One more than the data module's default maximum initial
            # rotation error.
            prior_range_rotation=(
                KittiDataModule.default_cfg["max_init_error_rotation"] + 1
            ),
            # Both initial-error settings are zeroed for the sequential case.
            max_init_error=0,
            max_init_error_rotation=0,
        ),
        chunking=dict(max_length=100),
    )
)
|
|
|
|
def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int, ...] = (1, 3, 5),
    **kwargs,
):
    """Evaluate an experiment on a KITTI split and log recall values.

    Args:
        split: Dataset split name, forwarded to ``evaluate``.
        experiment: Experiment identifier, forwarded to ``evaluate``.
        cfg: Optional overrides (plain ``dict`` or ``DictConfig``) merged on
            top of the default config selected by ``sequential``.
        sequential: When true, ``default_cfg_sequential`` is used as the base
            config instead of ``default_cfg_single``, the flag is forwarded
            to ``evaluate``, and the sequential metrics are logged as well.
        thresholds: Thresholds passed to each metric's ``recall`` method.
        **kwargs: Extra keyword arguments forwarded to ``evaluate``.

    Returns:
        The metrics mapping returned by ``evaluate``.
    """
    # Normalise the overrides: a missing/empty config becomes an empty dict,
    # and plain dicts are converted so they can be merged by OmegaConf.
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    # Later arguments win in the merge, so caller overrides beat the defaults.
    cfg = OmegaConf.merge(default, cfg)
    dataset = KittiDataModule(cfg.get("data", {}))

    metrics = evaluate(
        experiment,
        cfg,
        dataset,
        split=split,
        sequential=sequential,
        viz_kwargs=dict(show_dir_error=True, show_masked_prob=False),
        **kwargs,
    )

    keys = ["directional_error", "yaw_max_error"]
    if sequential:
        keys += ["directional_seq_error", "yaw_seq_error"]
    for k in keys:
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: named options select the experiment and
    # split, while any remaining positional tokens are read as an OmegaConf
    # dotlist of config overrides.
    cli = argparse.ArgumentParser()
    cli.add_argument("--experiment", required=True, type=str)
    cli.add_argument(
        "--split",
        choices=["test", "val", "train"],
        default="test",
        type=str,
    )
    cli.add_argument("--sequential", action="store_true")
    cli.add_argument("--output_dir", type=Path)
    cli.add_argument("--num", type=int)
    cli.add_argument("dotlist", nargs="*")
    opts = cli.parse_args()

    overrides = OmegaConf.from_cli(opts.dotlist)
    run(
        split=opts.split,
        experiment=opts.experiment,
        cfg=overrides,
        sequential=opts.sequential,
        output_dir=opts.output_dir,
        num=opts.num,
    )
|
|