Spaces:
Sleeping
Sleeping
File size: 2,714 Bytes
7c21061 8c5e6cc 7c21061 8c5e6cc 7c21061 5e40307 7c21061 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
#!/usr/bin/env python3
"""
Command-line script that predicts a trace (trajectory) on an image using the trace model.
Reuses build_prompt, load_model and run_inference from trace_inference.
"""
import argparse
import os
import shutil
import sys
from trace_inference import DEFAULT_MODEL_ID, build_prompt, load_model, run_inference
def _build_parser():
    """Return the argument parser for the trace-prediction CLI."""
    parser = argparse.ArgumentParser(
        description="Predict trace/trajectory on an image using mihirgrao/trace-model"
    )
    parser.add_argument("image", type=str, help="Path to input image")
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="Path to save overlay image (default: <image>_trace.png)",
    )
    parser.add_argument(
        "-m",
        "--model-id",
        type=str,
        default=DEFAULT_MODEL_ID,
        help=f"Model ID (default: {DEFAULT_MODEL_ID})",
    )
    parser.add_argument(
        "-i",
        "--instruction",
        type=str,
        default="",
        help="Natural language task instruction (e.g. 'Pick up the red block and place it on the table')",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        default=None,
        help="Full prompt override (if set, ignores --instruction)",
    )
    return parser


def _fail(message):
    """Print *message* to stderr with a single ``Error:`` prefix and exit with status 1."""
    # Messages coming back from the inference helpers may already carry the
    # prefix; avoid printing "Error: Error: ...".
    if not message.startswith("Error:"):
        message = f"Error: {message}"
    print(message, file=sys.stderr)
    sys.exit(1)


def _print_prediction(prediction, trace_text):
    """Print the raw model prediction followed by the extracted trace text."""
    print("\nModel prediction (raw):")
    print(prediction)
    print("\n" + trace_text)


def _default_output_path(image_path, overlay_path):
    """Return ``<image stem>_trace<ext>`` next to the input image.

    The extension is taken from the overlay file being copied, so the saved
    file's name matches its actual image format, rather than from the input
    image. It falls back to ``.png`` (the format promised by ``--output``'s
    help text) when the overlay path has no extension.
    """
    base, _ = os.path.splitext(image_path)
    ext = os.path.splitext(overlay_path)[1] or ".png"
    return f"{base}_trace{ext}"


def main():
    """Run trace inference on an image from the command line and report the result.

    Steps: parse arguments, load the model, build (or take) the prompt, run
    inference, then copy the overlay image to the output path and print the
    raw prediction plus trace text.

    Exit status:
        1 if the image is missing, the model fails to load, the prediction
          looks like an error/usage message, or the overlay cannot be saved.
        0 if inference succeeded but no trajectory points were extracted
          (no overlay to save).
    """
    args = _build_parser().parse_args()

    # isfile (not exists) so a directory is rejected up front.
    if not os.path.isfile(args.image):
        _fail(f"Image not found: {args.image}")

    success, msg = load_model(args.model_id)
    if not success:
        _fail(msg)
    print(f"✓ {msg}")

    # An explicit --prompt takes precedence over building one from --instruction.
    prompt = args.prompt if args.prompt is not None else build_prompt(args.instruction)

    prediction, overlay_path, trace_text = run_inference(
        args.image, prompt, args.model_id
    )

    # run_inference signals failures in-band through the prediction text.
    if prediction.startswith(("Error:", "Please ")):
        _fail(prediction)

    if overlay_path is None:
        _print_prediction(prediction, trace_text)
        print("\nNo trajectory points were extracted from the prediction.")
        sys.exit(0)

    output_path = args.output
    if output_path is None:
        output_path = _default_output_path(args.image, overlay_path)

    try:
        shutil.copy(overlay_path, output_path)
    except OSError as err:
        _fail(f"Could not save overlay to {output_path}: {err}")
    finally:
        # The overlay is a temp file; remove it even when the copy failed so
        # it is not leaked.
        os.unlink(overlay_path)

    print(f"\n✓ Overlay saved to: {output_path}")
    _print_prediction(prediction, trace_text)
if __name__ == "__main__":
    # Script entry point. main() returns None on success, so this exits 0;
    # any explicit sys.exit inside main() propagates unchanged.
    raise SystemExit(main())
|