from __future__ import annotations import argparse from pathlib import Path from ultralytics import YOLO from models import register_ultralytics_modules ROOT = Path(__file__).resolve().parent DEFAULT_WEIGHTS = ROOT / "weights" / "symbolic_capsule_network_segmentation.pt" def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Run Symbolic Capsule Network segmentation inference.") parser.add_argument("source", help="Image, directory, video, or glob pattern.") parser.add_argument("--weights", default=str(DEFAULT_WEIGHTS), help="Checkpoint path.") parser.add_argument("--imgsz", type=int, default=640) parser.add_argument("--conf", type=float, default=0.25) parser.add_argument("--device", default="") parser.add_argument("--save", action="store_true", default=True) parser.add_argument("--show", action="store_true") return parser def main() -> None: args = build_parser().parse_args() weights = Path(args.weights).expanduser().resolve() if not weights.exists(): raise FileNotFoundError(f"Checkpoint not found: {weights}") register_ultralytics_modules() model = YOLO(str(weights)) predict_kwargs = {k: v for k, v in vars(args).items() if k != "weights"} model.predict(**predict_kwargs) if __name__ == "__main__": main()