#!/usr/bin/env python3
"""
FastAPI server for Trace Model inference.

Usage:
    python eval_server.py --model-id mihirgrao/trace-model --port 8000

Endpoints:
    POST /predict       - Single image + instruction
    POST /predict_batch - Batch of (image, instruction) pairs
    GET /health         - Health check
    GET /model_info     - Model information
"""

import argparse
import base64
import io
import logging
import os
import re
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Any, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware

from trace_inference import (
    DEFAULT_MODEL_ID,
    build_prompt,
    load_model,
    run_inference,
)
from trace_inference import _model_state as _trace_model_state
from trajectory_viz import extract_trajectory_from_text

logger = logging.getLogger(__name__)


# --- Trace Eval Server ---

class TraceEvalServer:
    """Inference server for the trace model."""

    def __init__(
        self,
        model_id: str = DEFAULT_MODEL_ID,
        max_workers: int = 1,
    ):
        self.model_id = model_id
        self.max_workers = max_workers
        self._job_counter = 0
        self._completed_jobs = 0
        self._lock = Lock()
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

        logger.info(f"Loading trace model: {model_id}")
        success, msg = load_model(model_id)
        if not success:
            raise RuntimeError(f"Failed to load model: {msg}")
        logger.info(msg)

    def predict_one(
        self,
        image_path: Optional[str] = None,
        image_base64: Optional[str] = None,
        instruction: str = "",
        is_oxe: bool = False,
    ) -> Dict[str, Any]:
        """
        Run inference on a single image.

        Provide either image_path (file path) or image_base64 (base64-encoded image).
        """
        if image_path is None and image_base64 is None:
            return {"error": "Provide image_path or image_base64"}

        temp_file_path = None
        if image_path is None:
            try:
                # Strip data URL prefix if present (e.g. "data:image/png;base64,")
                b64_str = image_base64.strip()
                if b64_str.startswith("data:"):
                    match = re.match(r"data:image/[^;]+;base64,(.+)", b64_str, re.DOTALL)
                    if match:
                        b64_str = match.group(1)
                image_bytes = base64.b64decode(b64_str, validate=False)

                # Load via BytesIO to validate and get proper format, then save
                from PIL import Image

                img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
                    img.save(f.name, format="PNG")
                    image_path = f.name
                temp_file_path = image_path
            except Exception as e:
                return {"error": f"Invalid image data: {e}"}

        try:
            prompt = build_prompt(instruction, is_oxe=is_oxe)
            prediction, _, _ = run_inference(image_path, prompt, self.model_id)
        finally:
            # Clean up the temporary file created for base64 input, if any.
            if temp_file_path and os.path.exists(temp_file_path):
                try:
                    os.unlink(temp_file_path)
                except Exception:
                    pass

        if prediction.startswith("Error:") or prediction.startswith("Please "):
            return {"error": prediction}

        trajectory = extract_trajectory_from_text(prediction)
        result: Dict[str, Any] = {
            "prediction": prediction,
            "trajectory": trajectory,
        }
        return result

    def predict_batch(
        self,
        samples: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Process a batch of (image_path or image_base64, instruction) samples."""

        def _run_sample(sample: Dict[str, Any]) -> Dict[str, Any]:
            with self._lock:
                self._job_counter += 1
                job_id = self._job_counter
            start = time.time()
            result = self.predict_one(
                image_path=sample.get("image_path"),
                image_base64=sample.get("image_base64"),
                instruction=sample.get("instruction", ""),
                is_oxe=sample.get("is_oxe", False),
            )
            elapsed = time.time() - start
            with self._lock:
                self._completed_jobs += 1
            logger.debug(f"[job {job_id}] completed in {elapsed:.3f}s")
            return result

        # Dispatch through the worker pool so max_workers actually bounds batch
        # concurrency; map() preserves input order (serial when max_workers=1).
        results = list(self.executor.map(_run_sample, samples))
        return {"results": results}

    def get_status(self) -> Dict[str, Any]:
        """Get server status."""
        return {
            "model_id": self.model_id,
            "max_workers": self.max_workers,
            "completed_jobs": self._completed_jobs,
            "job_counter": self._job_counter,
        }

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information."""
        try:
            model = _trace_model_state.get("model")
            if model is None:
                return {"model_id": self.model_id, "status": "not_loaded"}
            all_params = sum(p.numel() for p in model.parameters())
            return {
                "model_id": self.model_id,
                "model_class": model.__class__.__name__,
                "total_parameters": all_params,
            }
        except Exception as e:
            return {"model_id": self.model_id, "error": str(e)}

    def shutdown(self):
        """Shutdown the executor."""
        self.executor.shutdown(wait=True)
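
# Offline usage sketch (an illustration, not part of the server; the image path
# and instruction below are placeholders):
#
#   server = TraceEvalServer(model_id=DEFAULT_MODEL_ID)
#   out = server.predict_one(image_path="frame.png", instruction="push the cube left")
#   print(out.get("trajectory"))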


def create_app(
    model_id: str = DEFAULT_MODEL_ID,
    max_workers: int = 1,
    server: Optional[TraceEvalServer] = None,
) -> FastAPI:
    app = FastAPI(title="Trace Model Evaluation Server")
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    trace_server = server or TraceEvalServer(model_id=model_id, max_workers=max_workers)

    @app.post("/predict")
    async def predict(request: Request) -> Dict[str, Any]:
        """
        Predict trace for a single image.

        JSON body:
            - image_path: (optional) path to image file
            - image_base64: (optional) base64-encoded image
            - instruction: natural language task description
            - is_oxe: (optional) if true, use OXE prompt format
        """
        body = await request.json()
        return trace_server.predict_one(
            image_path=body.get("image_path"),
            image_base64=body.get("image_base64"),
            instruction=body.get("instruction", ""),
            is_oxe=body.get("is_oxe", False),
        )

    @app.post("/predict_batch")
    async def predict_batch(request: Request) -> Dict[str, Any]:
        """
        Predict trace for a batch of images.

        JSON body:
            - samples: list of {image_path?, image_base64?, instruction}
        """
        body = await request.json()
        samples = body.get("samples", [])
        if not samples:
            return {"error": "samples list is required", "results": []}
        return trace_server.predict_batch(samples)

    @app.post("/evaluate_batch")
    async def evaluate_batch(request: Request) -> Dict[str, Any]:
        """
        Alias for /predict_batch for compatibility with RFM-style clients.

        Accepts the same format as /predict_batch.
        """
        return await predict_batch(request)

    @app.get("/health")
    def health() -> Dict[str, Any]:
        """Health check."""
        status = trace_server.get_status()
        return {
            "status": "healthy",
            "model_id": status["model_id"],
        }

    @app.get("/model_info")
    def model_info() -> Dict[str, Any]:
        """Get model information."""
        return trace_server.get_model_info()

    @app.get("/gpu_status")
    def gpu_status() -> Dict[str, Any]:
        """Get server status (RFM-compatible endpoint name)."""
        return trace_server.get_status()

    @app.on_event("shutdown")
    async def shutdown_event():
        trace_server.shutdown()

    return app


def main():
    parser = argparse.ArgumentParser(description="Trace Model Evaluation Server")
    parser.add_argument(
        "--model-id",
        type=str,
        default=DEFAULT_MODEL_ID,
        help=f"Model ID (default: {DEFAULT_MODEL_ID})",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Server host",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8001,
        help="Server port",
    )
    parser.add_argument(
        "--max-workers",
        type=int,
        default=1,
        help="Max worker threads for batch processing",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    app = create_app(model_id=args.model_id, max_workers=args.max_workers)
    print(f"Trace eval server starting on {args.host}:{args.port}")
    print(f"Model: {args.model_id}")
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()
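
# Minimal client sketch (an illustration only; `requests` is not a dependency
# of this module, the port matches the --port default above, and the path and
# instruction are placeholders):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8001/predict",
#       json={"image_path": "/path/to/frame.png", "instruction": "open the drawer"},
#   )
#   print(resp.json())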