Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| CLI script to predict trace on an image using the trace model. | |
| Reuses load_model and run_inference from trace_inference. | |
| """ | |
| import argparse | |
| import os | |
| import shutil | |
| import sys | |
| from trace_inference import DEFAULT_MODEL_ID, build_prompt, load_model, run_inference | |
def _fail(message):
    """Print *message* to stderr and terminate the process with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _default_output_path(image_path, overlay_path):
    """Return the default overlay destination ``<image stem>_trace<ext>``.

    The extension is taken from the overlay file produced by inference (not
    from the input image), so the saved file's name matches the bytes that
    are copied into it, e.g. a ``.png`` overlay is never saved as
    ``photo_trace.jpg``. Falls back to ``.png`` when the overlay path has
    no extension.

    Args:
        image_path: Path of the input image given on the command line.
        overlay_path: Path of the overlay file returned by ``run_inference``.

    Returns:
        The path the overlay should be saved to.
    """
    base, _ = os.path.splitext(image_path)
    ext = os.path.splitext(overlay_path)[1] or ".png"
    return f"{base}_trace{ext}"


def _print_prediction(prediction, trace_text):
    """Print the raw model output followed by the extracted trace text."""
    print("\nModel prediction (raw):")
    print(prediction)
    print("\n" + trace_text)


def main():
    """Run trace prediction on a single image from the command line.

    Parses CLI arguments, loads the model, builds the prompt (``--prompt``
    overrides ``--instruction``), runs inference and saves the overlay image.

    Exits with status 1 when the image file is missing, the model fails to
    load, inference reports an error, or the overlay cannot be saved. Exits
    with status 0 without writing an overlay when no trajectory points were
    extracted.
    """
    parser = argparse.ArgumentParser(
        description="Predict trace/trajectory on an image using mihirgrao/trace-model"
    )
    parser.add_argument("image", type=str, help="Path to input image")
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="Path to save overlay image (default: <image>_trace.png)",
    )
    parser.add_argument(
        "-m",
        "--model-id",
        type=str,
        default=DEFAULT_MODEL_ID,
        help=f"Model ID (default: {DEFAULT_MODEL_ID})",
    )
    parser.add_argument(
        "-i",
        "--instruction",
        type=str,
        default="",
        help="Natural language task instruction (e.g. 'Pick up the red block and place it on the table')",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        default=None,
        help="Full prompt override (if set, ignores --instruction)",
    )
    args = parser.parse_args()

    # isfile (not exists) so a directory is not mistaken for an image.
    if not os.path.isfile(args.image):
        _fail(f"Error: Image not found: {args.image}")

    # Load model
    success, msg = load_model(args.model_id)
    if not success:
        _fail(f"Error: {msg}")
    print(f"✓ {msg}")

    # An explicit --prompt wins; otherwise build one from the instruction.
    prompt = args.prompt if args.prompt is not None else build_prompt(args.instruction)

    prediction, overlay_path, trace_text = run_inference(
        args.image, prompt, args.model_id
    )

    # run_inference reports failures as text; avoid printing "Error: Error: ...".
    if prediction.startswith(("Error:", "Please ")):
        if prediction.startswith("Error:"):
            _fail(prediction)
        _fail(f"Error: {prediction}")

    if overlay_path is None:
        _print_prediction(prediction, trace_text)
        print("\nNo trajectory points were extracted from the prediction.")
        sys.exit(0)

    output_path = args.output
    if output_path is None:
        output_path = _default_output_path(args.image, overlay_path)

    try:
        shutil.copy(overlay_path, output_path)
    except OSError as err:
        _fail(f"Error: could not save overlay to {output_path}: {err}")
    finally:
        # The overlay from run_inference is a temp file; remove it even when
        # the copy fails so failed runs do not leak temporary files.
        os.unlink(overlay_path)

    print(f"\n✓ Overlay saved to: {output_path}")
    _print_prediction(prediction, trace_text)


if __name__ == "__main__":
    main()