File size: 2,714 Bytes
7c21061
 
 
 
8c5e6cc
7c21061
 
 
 
 
 
 
8c5e6cc
7c21061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e40307
7c21061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python3
"""
CLI script to predict trace on an image using the trace model.

Reuses load_model and run_inference from trace_inference.
"""

import argparse
import os
import shutil
import sys

from trace_inference import DEFAULT_MODEL_ID, build_prompt, load_model, run_inference


def _build_parser():
    """Return the argument parser for the trace-prediction CLI."""
    parser = argparse.ArgumentParser(
        description="Predict trace/trajectory on an image using mihirgrao/trace-model"
    )
    parser.add_argument("image", type=str, help="Path to input image")
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="Path to save overlay image (default: <image>_trace.png)",
    )
    parser.add_argument(
        "-m",
        "--model-id",
        type=str,
        default=DEFAULT_MODEL_ID,
        help=f"Model ID (default: {DEFAULT_MODEL_ID})",
    )
    parser.add_argument(
        "-i",
        "--instruction",
        type=str,
        default="",
        help="Natural language task instruction (e.g. 'Pick up the red block and place it on the table')",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        default=None,
        help="Full prompt override (if set, ignores --instruction)",
    )
    return parser


def _default_output_path(image_path, overlay_path):
    """Return ``<image base>_trace<ext>`` for the saved overlay.

    The overlay file is copied verbatim, so the extension is taken from the
    overlay file (which reflects its real encoding) rather than from the input
    image. Falls back to ``.png`` (the documented default) if the overlay path
    has no extension.
    """
    base, _ = os.path.splitext(image_path)
    ext = os.path.splitext(overlay_path)[1] or ".png"
    return f"{base}_trace{ext}"


def _print_prediction(prediction, trace_text):
    """Print the raw model output followed by the formatted trace text."""
    print("\nModel prediction (raw):")
    print(prediction)
    print("\n" + trace_text)


def main():
    """Predict a trace for one image and save the overlay image.

    Exits with status 1 if the image path is not a file, the model fails to
    load, inference reports an error, or the overlay cannot be saved. Exits
    with status 0 without writing a file when inference returns no overlay
    (no trajectory points were extracted).
    """
    args = _build_parser().parse_args()

    # isfile (not exists) so a directory path is rejected before loading the model.
    if not os.path.isfile(args.image):
        print(f"Error: Image not found: {args.image}", file=sys.stderr)
        sys.exit(1)

    # Load model
    success, msg = load_model(args.model_id)
    if not success:
        print(f"Error: {msg}", file=sys.stderr)
        sys.exit(1)
    print(f"✓ {msg}")

    # An explicit --prompt overrides the instruction-based prompt.
    prompt = args.prompt if args.prompt is not None else build_prompt(args.instruction)

    prediction, overlay_path, trace_text = run_inference(
        args.image, prompt, args.model_id
    )

    # run_inference reports failures through the prediction string.
    if prediction.startswith(("Error:", "Please ")):
        # Don't double the prefix when the message already starts with "Error:".
        message = (
            prediction if prediction.startswith("Error:") else f"Error: {prediction}"
        )
        print(message, file=sys.stderr)
        sys.exit(1)

    if overlay_path is None:
        _print_prediction(prediction, trace_text)
        print("\nNo trajectory points were extracted from the prediction.")
        sys.exit(0)

    if args.output is not None:
        output_path = args.output
    else:
        output_path = _default_output_path(args.image, overlay_path)

    try:
        shutil.copy(overlay_path, output_path)
    except OSError as err:
        print(
            f"Error: could not save overlay to {output_path}: {err}",
            file=sys.stderr,
        )
        sys.exit(1)
    finally:
        # Always remove the temporary overlay, even if the copy failed.
        # Best-effort: a failed cleanup should not mask the result above.
        try:
            os.unlink(overlay_path)
        except OSError:
            pass
    print(f"\n✓ Overlay saved to: {output_path}")

    _print_prediction(prediction, trace_text)


if __name__ == "__main__":
    # Script entry point: main() returns None, so the exit status is 0 on success.
    sys.exit(main())