#!/usr/bin/env python3
"""
CLI script to predict a trace on an image using the trace model.

Reuses load_model and run_inference from trace_inference.
"""
import argparse
import os
import shutil
import sys

from trace_inference import DEFAULT_MODEL_ID, build_prompt, load_model, run_inference


def main():
    parser = argparse.ArgumentParser(
        description="Predict trace/trajectory on an image using mihirgrao/trace-model"
    )
    parser.add_argument("image", type=str, help="Path to input image")
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="Path to save overlay image (default: <image>_trace.<ext>)",
    )
    parser.add_argument(
        "-m",
        "--model-id",
        type=str,
        default=DEFAULT_MODEL_ID,
        help=f"Model ID (default: {DEFAULT_MODEL_ID})",
    )
    parser.add_argument(
        "-i",
        "--instruction",
        type=str,
        default="",
        help="Natural language task instruction (e.g. 'Pick up the red block and place it on the table')",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        default=None,
        help="Full prompt override (if set, ignores --instruction)",
    )
    args = parser.parse_args()

    if not os.path.exists(args.image):
        print(f"Error: Image not found: {args.image}", file=sys.stderr)
        sys.exit(1)

    # Load model
    success, msg = load_model(args.model_id)
    if not success:
        print(f"Error: {msg}", file=sys.stderr)
        sys.exit(1)
    print(f"āœ“ {msg}")

    # Build prompt from instruction
    prompt = args.prompt if args.prompt is not None else build_prompt(args.instruction)

    # Run inference
    prediction, overlay_path, trace_text = run_inference(
        args.image, prompt, args.model_id
    )

    # Handle errors
    if prediction.startswith("Error:") or prediction.startswith("Please "):
        print(f"Error: {prediction}", file=sys.stderr)
        sys.exit(1)

    if overlay_path is None:
        print("\nModel prediction (raw):")
        print(prediction)
        print("\n" + trace_text)
        print("\nNo trajectory points were extracted from the prediction.")
        sys.exit(0)

    # Save overlay to the desired path if specified
    output_path = args.output
    if output_path is None:
        base, ext = os.path.splitext(args.image)
        output_path = f"{base}_trace{ext}"

    shutil.copy(overlay_path, output_path)
    os.unlink(overlay_path)  # Remove temp file

    print(f"\nāœ“ Overlay saved to: {output_path}")
    print("\nModel prediction (raw):")
    print(prediction)
    print("\n" + trace_text)


if __name__ == "__main__":
    main()
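
# Example invocations (a sketch, not part of the script's behavior: the filename
# predict_trace.py and the image names are assumptions; trace_inference.py must be
# importable from the same directory):
#
#   python predict_trace.py scene.png -i "Pick up the red block and place it on the table"
#   python predict_trace.py scene.png -o scene_overlay.png
#   python predict_trace.py scene.png -p "Full custom prompt text"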