| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
|
|
| from ultralytics import YOLO |
|
|
| from models import register_ultralytics_modules |
|
|
|
|
# Directory that contains this script; bundled assets are located relative to it.
ROOT = Path(__file__).resolve().parent
# Checkpoint path used as the default value of the --weights option.
DEFAULT_WEIGHTS = ROOT.joinpath("weights", "symbolic_capsule_network_segmentation.pt")
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for segmentation inference.

    Every parsed option except ``weights`` is forwarded by ``main`` to
    ``model.predict`` as a keyword argument under its destination name, so
    no destination other than the existing ones may be introduced here.

    Returns:
        A parser exposing ``source``, ``--weights``, ``--imgsz``, ``--conf``,
        ``--device``, ``--save`` / ``--no-save`` and ``--show``.
    """
    parser = argparse.ArgumentParser(description="Run Symbolic Capsule Network segmentation inference.")
    parser.add_argument("source", help="Image, directory, video, or glob pattern.")
    parser.add_argument("--weights", default=str(DEFAULT_WEIGHTS), help="Checkpoint path.")
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument("--conf", type=float, default=0.25)
    parser.add_argument("--device", default="")

    # ``--save`` alone (store_true with default=True) could never yield False,
    # so saving could not be turned off.  Both flags write to the same ``save``
    # destination: the default stays True and ``--save`` is still accepted,
    # while ``--no-save`` provides the missing opt-out.
    parser.add_argument(
        "--save",
        dest="save",
        action="store_true",
        default=True,
        help="Save prediction outputs (default: enabled).",
    )
    parser.add_argument(
        "--no-save",
        dest="save",
        action="store_false",
        help="Do not save prediction outputs.",
    )

    parser.add_argument("--show", action="store_true")
    return parser
|
|
|
|
def main() -> None:
    """Parse the command line, load the checkpoint, and run prediction.

    Raises:
        FileNotFoundError: If the resolved ``--weights`` path does not exist.
    """
    options = vars(build_parser().parse_args())

    # "weights" only selects the checkpoint; every remaining option is passed
    # straight through to ``predict`` below.
    checkpoint = Path(options.pop("weights")).expanduser().resolve()
    if not checkpoint.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")

    # Custom modules are registered before the checkpoint is loaded
    # (presumably so the checkpoint's custom layers can be resolved --
    # TODO confirm against ``models``).
    register_ultralytics_modules()
    model = YOLO(str(checkpoint))

    model.predict(**options)
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|