Spaces:
Running
Running
| """ | |
| SQLForge — demo server. | |
A FastAPI app that loads the fine-tuned text-to-SQL model and turns a
natural-language question + database schema into a SQL query. When a real SQLite
| database is available it also *runs* the query, shows the results, and uses | |
| self-correction (feed the DB error back to the model and retry) so you can watch | |
| the agent fix its own mistakes. | |
| Run locally: | |
| uvicorn app.server:app --reload --port 8000 | |
| Then open http://localhost:8000 | |
| """ | |
| import os | |
| import sqlite3 | |
| import threading | |
| import time | |
| from pathlib import Path | |
# Model + tokenizer are loaded from local paths, so keep the Hugging Face hub
# offline at serve time instead of reaching for its (flaky) CDN. An explicit
# HF_HUB_OFFLINE already present in the environment is left untouched. This sits
# above the third-party/sqlforge imports, presumably because the HF libraries
# read the variable when they are first imported.
if "HF_HUB_OFFLINE" not in os.environ:
    os.environ["HF_HUB_OFFLINE"] = "1"
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from sqlforge.exec_eval import run_sql, schema_from_sqlite | |
| from sqlforge.inference import generate_sql, generate_sql_with_retry, load_model | |
| ROOT = Path(__file__).resolve().parent.parent | |
| STATIC = Path(__file__).resolve().parent / "static" | |
| # --- config (all overridable by env, so HF Spaces / CI can swap paths) ---------- | |
| BASE_MODEL = os.environ.get("SQLFORGE_BASE", "models/qwen2.5-coder-1.5b") | |
| if (ROOT / BASE_MODEL).is_dir(): | |
| BASE_MODEL = str(ROOT / BASE_MODEL) | |
| ADAPTER = os.environ.get("SQLFORGE_ADAPTER", str(ROOT / "outputs" / "qwen2.5-coder-1.5b-sql")) | |
| FOUR_BIT = os.environ.get("SQLFORGE_4BIT", "1") == "1" | |
| DB_DIR = Path(os.environ.get("SQLFORGE_DB_DIR", | |
| str(ROOT / "data" / "spider_raw" / "spider_data" / "database"))) | |
| MAX_ROWS = 100 # cap result rows sent to the browser | |
| GEN_TOKENS = 192 # SQL is short; cap new tokens for a snappier demo | |
| MAX_RETRIES = 1 # one self-correction retry (worst case ~2 generations, not 3) | |
def _example(db_id, label, *questions):
    """Build one curated-example entry: {"db_id", "label", "questions"}."""
    return {"db_id": db_id, "label": label, "questions": list(questions)}


# Curated example databases; every question below is verified to run cleanly
# against the real DB. The last "Cars" question deliberately triggers
# self-correction (a JOIN the model fixes itself) to showcase the agentic
# recovery succeeding.
EXAMPLES = [
    _example("concert_singer", "Concerts & Singers",
             "How many singers do we have?",
             "What is the average, minimum, and maximum age of all singers?",
             "Show the name and country of all singers ordered by age from oldest to youngest."),
    _example("pets_1", "Students & Pets",
             "How many pets are there?",
             "What is the average weight of all pets?",
             "How many students are there?"),
    _example("world_1", "World (countries)",
             "What are the names of all countries that became independent after 1950?",
             "How many countries have a republic as their form of government?",
             "What is the average life expectancy of countries in Africa?"),
    _example("car_1", "Cars",
             "How many continents are there?",
             "What is the maximum horsepower of any car?",
             "How many countries does each continent have? List continent id, name and count."),
    _example("student_transcripts_tracking", "Student Transcripts",
             "How many courses in total are listed?",
             "How many students are there?",
             "List the first and last name of every student."),
]
# --- model state (loaded in the background so the server starts instantly) ------
# Status blob returned by the health handler; _load() fills in device/status/error.
STATE = dict(status="loading", model=ADAPTER, device=None, error=None)
# Loaded model and tokenizer; both stay None until _load() succeeds.
_MODEL = dict(model=None, tok=None)
# Serializes generation: one request at a time (single GPU).
_LOCK = threading.Lock()
def _load():
    """Load the base model plus adapter and publish the outcome in STATE.

    On success, stores the model/tokenizer in _MODEL and sets STATE to "online"
    with the model's device. On any failure, sets STATE to "error" with the
    message. Nothing is raised, so the failure is visible through STATE rather
    than crashing the calling thread.
    """
    try:
        model, tok = load_model(BASE_MODEL, adapter_path=ADAPTER, four_bit=FOUR_BIT)
        _MODEL.update(model=model, tok=tok)
        # "cuda" is only a display fallback for objects without a .device attribute
        device = str(getattr(model, "device", "cuda"))
        STATE["device"] = device
        STATE["status"] = "online"
        print(f"[sqlforge] model online ({device}, 4bit={FOUR_BIT})")
    except Exception as exc:  # noqa: BLE001 - any load failure is reported, not raised
        STATE.update(status="error", error=str(exc))
        print(f"[sqlforge] model failed to load: {exc}")
app = FastAPI(
    title="SQLForge",
    description="Fine-tuned text-to-SQL demo",
)


def _startup():
    """Start _load() on a daemon thread so the server can answer requests while the model loads.

    NOTE(review): how this hook is registered with the app is not visible in
    this view - confirm it is wired to application startup.
    """
    loader = threading.Thread(target=_load, daemon=True)
    loader.start()
| def _db_path(db_id: str) -> Path | None: | |
| """Resolve a known example db_id to its .sqlite file (no path traversal).""" | |
| if not db_id or "/" in db_id or "\\" in db_id or ".." in db_id: | |
| return None | |
| p = DB_DIR / db_id / f"{db_id}.sqlite" | |
| return p if p.exists() else None | |
# --- API ------------------------------------------------------------------------
class GenerateRequest(BaseModel):
    """Request body for SQL generation.

    Supply ``schema_text`` (custom mode), ``db_id`` (example mode), or both.
    When both are given, the explicit ``schema_text`` is used for the prompt and
    ``db_id`` still selects the database the generated SQL is executed against.
    """

    question: str  # natural-language question; the handler rejects blank input
    schema_text: str | None = None  # raw CREATE TABLE text (custom mode)
    db_id: str | None = None  # example DB id (executes + self-corrects)
    self_correct: bool = True  # retry with DB-error feedback; only used when db_id resolves to a DB file
def health():
    """Return the live model-status dict (status, model, device, error)."""
    snapshot = STATE
    return snapshot
def examples():
    """List the curated examples whose .sqlite file is actually present on this host."""
    available = []
    for entry in EXAMPLES:
        if _db_path(entry["db_id"]):
            available.append(entry)
    return available
def schema(db_id: str):
    """Return the schema text of a known database; responds 404 for an unknown id."""
    path = _db_path(db_id)
    if path is None:
        raise HTTPException(404, f"unknown database '{db_id}'")
    return {"db_id": db_id, "schema": schema_from_sqlite(path)}
def generate(req: GenerateRequest):
    """Turn a question (+ schema or example DB) into SQL, executing it when a DB is available.

    The schema comes from ``req.schema_text`` or, if that is empty, from the
    example database. With a database and ``self_correct``, generation goes
    through the retry loop that validates each candidate against the DB.
    The final SQL is then run read-only and a result preview is returned.

    Raises:
        HTTPException: 503 while the model is not online; 400 for a blank
            question or when no schema can be determined.
    """
    if STATE["status"] != "online":
        raise HTTPException(503, f"model not ready ({STATE['status']})")
    if not req.question.strip():
        raise HTTPException(400, "question is required")
    db_path = _db_path(req.db_id) if req.db_id else None
    schema_text = req.schema_text
    if db_path and not schema_text:
        schema_text = schema_from_sqlite(db_path)
    if not schema_text:
        raise HTTPException(400, "provide a schema or pick an example database")

    model, tok = _MODEL["model"], _MODEL["tok"]
    trace: list = []
    t0 = time.perf_counter()  # monotonic: wall-clock jumps can't skew elapsed_s
    with _LOCK:  # one generation at a time (single GPU)
        if db_path and req.self_correct:
            # NOTE(review): run_sql's connection mode is not visible here. If it
            # opens the DB read-write, destructive SQL can still run during
            # validation; confirm it is read-only. Index [1] is presumably the
            # error message (None on success).
            sql, attempts = generate_sql_with_retry(
                model, tok, schema_text, req.question,
                validate=lambda s: run_sql(db_path, s)[1],
                max_retries=MAX_RETRIES, max_new_tokens=GEN_TOKENS, trace=trace)
        else:
            sql = generate_sql(model, tok, schema_text, req.question,
                               max_new_tokens=GEN_TOKENS)
            attempts = 1
    elapsed = round(time.perf_counter() - t0, 2)
    resp = {"sql": sql, "attempts": attempts, "trace": trace,
            "self_corrected": attempts > 1, "elapsed_s": elapsed,
            "executed": False, "columns": None, "rows": None,
            "row_count": None, "error": None}
    # if we have the real DB, run the final query and return the result preview
    if db_path:
        resp.update(_execute_preview(db_path, sql))
    return resp


# Wall-clock budget for running a generated query against an example DB. A
# model-written cartesian JOIN could otherwise tie up a worker thread indefinitely.
QUERY_TIMEOUT_S = 10.0


def _execute_preview(db_path: Path, sql: str) -> dict:
    """Run *sql* read-only against *db_path* and return the response's result fields.

    The connection uses SQLite's ``mode=ro`` URI flag. Model-generated (and so
    user-influenced) SQL such as ``DROP TABLE``, which Python's sqlite3 runs in
    autocommit, therefore cannot modify the shared example databases, and a
    missing file is not silently created. A progress handler aborts the
    statement after QUERY_TIMEOUT_S seconds. The resulting error is reported
    like any other DB error.

    Returns:
        ``{"executed": True, "columns", "rows", "row_count"}`` on success (rows
        capped at MAX_ROWS), or ``{"executed": True, "error": message}``.
    """
    deadline = time.monotonic() + QUERY_TIMEOUT_S
    conn = None
    try:
        conn = sqlite3.connect(db_path.resolve().as_uri() + "?mode=ro", uri=True)
        # a truthy return from the handler interrupts the running statement
        conn.set_progress_handler(lambda: time.monotonic() > deadline, 10_000)
        cur = conn.execute(sql)
        cols = [c[0] for c in cur.description] if cur.description else []
        rows = cur.fetchmany(MAX_ROWS)
    except Exception as exc:  # noqa: BLE001 - any DB error is shown to the user
        return {"executed": True, "error": str(exc)}
    finally:
        if conn is not None:
            conn.close()
    return {"executed": True, "columns": cols,
            "rows": [list(r) for r in rows], "row_count": len(rows)}
# --- static frontend (mounted last so /api/* wins) ------------------------------
def index():
    """Serve the single-page frontend (static/index.html)."""
    page = STATIC / "index.html"
    return FileResponse(page)


# The catch-all "/" mount is registered after the API handlers are defined.
_static_files = StaticFiles(directory=str(STATIC))
app.mount("/", _static_files, name="static")