sqlforge / app /server.py
Abdullahkousa2's picture
Upload folder using huggingface_hub
d6bfc8b verified
Raw
History Blame Contribute Delete
7.57 kB
"""
SQLForge — demo server.
A FastAPI app that loads the fine-tuned text-to-SQL model and turns a
natural-language question + database schema into a SQL query. When a real SQLite
database is available it also *runs* the query, shows the results, and uses
self-correction (feed the DB error back to the model and retry) so you can watch
the agent fix its own mistakes.
Run locally:
uvicorn app.server:app --reload --port 8000
Then open http://localhost:8000
"""
import os
import sqlite3
import threading
import time
from pathlib import Path
# the model + tokenizer are local; don't reach for the (flaky) HF CDN at serve time.
os.environ.setdefault("HF_HUB_OFFLINE", "1")
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sqlforge.exec_eval import run_sql, schema_from_sqlite
from sqlforge.inference import generate_sql, generate_sql_with_retry, load_model
# Filesystem anchors: ROOT is two directories above this file, STATIC sits
# next to this file.
ROOT = Path(__file__).resolve().parents[1]
STATIC = Path(__file__).resolve().with_name("static")

# --- config (all overridable by env, so HF Spaces / CI can swap paths) ----------
# Base model: if the configured value names a directory under ROOT, use that
# directory's full path; otherwise the raw value is kept unchanged.
BASE_MODEL = os.getenv("SQLFORGE_BASE", "models/qwen2.5-coder-1.5b")
BASE_MODEL = str(ROOT / BASE_MODEL) if (ROOT / BASE_MODEL).is_dir() else BASE_MODEL

# Adapter location passed to load_model().
ADAPTER = os.getenv(
    "SQLFORGE_ADAPTER",
    str(ROOT.joinpath("outputs", "qwen2.5-coder-1.5b-sql")),
)

# Only the exact string "1" enables 4-bit loading; any other value disables it.
FOUR_BIT = os.getenv("SQLFORGE_4BIT", "1") == "1"

# Directory holding one sub-directory per example database.
DB_DIR = Path(
    os.getenv(
        "SQLFORGE_DB_DIR",
        str(ROOT.joinpath("data", "spider_raw", "spider_data", "database")),
    )
)

MAX_ROWS = 100     # cap result rows sent to the browser
GEN_TOKENS = 192   # SQL is short; cap new tokens for a snappier demo
MAX_RETRIES = 1    # one self-correction retry (worst case ~2 generations, not 3)
# Curated example databases. Per the original author's note, every question
# listed here has been verified to run cleanly against the real database, and
# the final "Cars" question is chosen on purpose because it triggers the
# self-correction path (a JOIN the model repairs on retry).
EXAMPLES = [
    dict(
        db_id="concert_singer",
        label="Concerts & Singers",
        questions=[
            "How many singers do we have?",
            "What is the average, minimum, and maximum age of all singers?",
            "Show the name and country of all singers ordered by age from oldest to youngest.",
        ],
    ),
    dict(
        db_id="pets_1",
        label="Students & Pets",
        questions=[
            "How many pets are there?",
            "What is the average weight of all pets?",
            "How many students are there?",
        ],
    ),
    dict(
        db_id="world_1",
        label="World (countries)",
        questions=[
            "What are the names of all countries that became independent after 1950?",
            "How many countries have a republic as their form of government?",
            "What is the average life expectancy of countries in Africa?",
        ],
    ),
    dict(
        db_id="car_1",
        label="Cars",
        questions=[
            "How many continents are there?",
            "What is the maximum horsepower of any car?",
            "How many countries does each continent have? List continent id, name and count.",
        ],
    ),
    dict(
        db_id="student_transcripts_tracking",
        label="Student Transcripts",
        questions=[
            "How many courses in total are listed?",
            "How many students are there?",
            "List the first and last name of every student.",
        ],
    ),
]
# --- model state (loaded in the background so the server starts instantly) ------
# STATE is what /api/health reports; _load() fills in "device" and flips
# "status" to "online" or "error".
STATE = {
    "status": "loading",
    "model": ADAPTER,
    "device": None,
    "error": None,
}
# Holder for the loaded model and tokenizer; both stay None until _load() succeeds.
_MODEL = {
    "model": None,
    "tok": None,
}
# one generation at a time (single GPU)
_LOCK = threading.Lock()
def _load():
    """Load the model and tokenizer, recording the outcome in STATE.

    Intended to run on a background thread (see the startup hook). On success
    the model/tokenizer are stored in ``_MODEL`` and ``STATE`` is marked
    "online"; on any exception ``STATE`` is marked "error" with the message.
    Nothing is raised to the caller.
    """
    try:
        loaded_model, tokenizer = load_model(
            BASE_MODEL, adapter_path=ADAPTER, four_bit=FOUR_BIT
        )
        _MODEL.update(model=loaded_model, tok=tokenizer)
        # Fall back to the label "cuda" when the model exposes no .device attribute.
        STATE.update(
            device=str(getattr(loaded_model, "device", "cuda")),
            status="online",
        )
        print(f"[sqlforge] model online ({STATE['device']}, 4bit={FOUR_BIT})")
    except Exception as exc:  # noqa: BLE001
        # Boundary of the loader thread: report the failure via /api/health
        # instead of letting the thread die silently.
        STATE.update(status="error", error=str(exc))
        print(f"[sqlforge] model failed to load: {exc}")
app = FastAPI(
    title="SQLForge",
    description="Fine-tuned text-to-SQL demo",
)


@app.on_event("startup")
def _startup():
    # Kick off model loading on a daemon thread so the server can start
    # answering requests (e.g. /api/health) while the model is still loading.
    loader = threading.Thread(target=_load, daemon=True)
    loader.start()
def _db_path(db_id: str) -> Path | None:
"""Resolve a known example db_id to its .sqlite file (no path traversal)."""
if not db_id or "/" in db_id or "\\" in db_id or ".." in db_id:
return None
p = DB_DIR / db_id / f"{db_id}.sqlite"
return p if p.exists() else None
# --- API ------------------------------------------------------------------------
class GenerateRequest(BaseModel):
    # Request body for POST /api/generate.

    # The natural-language question to translate; must be non-blank.
    question: str
    # raw CREATE TABLE text (custom mode)
    schema_text: str | None = None
    # example DB id (executes + self-corrects)
    db_id: str | None = None
    # When True and db_id resolves to a database, generation goes through the
    # retry path that validates the SQL against that database.
    self_correct: bool = True
@app.get("/api/health")
def health():
    # Report the loader state: status, model, device and last load error.
    return STATE
@app.get("/api/examples")
def examples():
    # Only advertise the curated examples whose database file is actually
    # present on this machine.
    return [
        example
        for example in EXAMPLES
        if _db_path(example["db_id"]) is not None
    ]
@app.get("/api/schema")
def schema(db_id: str):
    # Return the schema text of one example database; 404 when the id does
    # not resolve to a database file.
    resolved = _db_path(db_id)
    if resolved is None:
        raise HTTPException(status_code=404, detail=f"unknown database '{db_id}'")
    schema_text = schema_from_sqlite(resolved)
    return {"db_id": db_id, "schema": schema_text}
def _preview_query(db_path: Path, sql: str) -> dict:
    """Run *sql* against the database at *db_path* and return a result preview.

    The SQL is model output steered by an untrusted question, so the database
    is opened read-only: a generated DROP/UPDATE/ATTACH fails with an error
    instead of modifying the demo data.

    Returns:
        Response fields to merge into the /api/generate payload. On success:
        ``executed``, ``columns``, ``rows`` (at most ``MAX_ROWS``) and
        ``row_count`` (number of rows returned, after the cap). On failure:
        ``executed`` and ``error`` (the exception message).
    """
    # as_uri() needs an absolute path and percent-encodes characters that
    # would otherwise be parsed as URI syntax ("?", "#", "%", spaces).
    conn = sqlite3.connect(f"{db_path.resolve().as_uri()}?mode=ro", uri=True)
    try:
        cur = conn.execute(sql)
        # description is None for statements that produce no result set.
        cols = [c[0] for c in cur.description] if cur.description else []
        rows = cur.fetchmany(MAX_ROWS)
        return {"executed": True, "columns": cols,
                "rows": [list(r) for r in rows], "row_count": len(rows)}
    except Exception as exc:  # noqa: BLE001
        # Best effort: a query that fails is reported to the client, not raised.
        return {"executed": True, "error": str(exc)}
    finally:
        conn.close()


@app.post("/api/generate")
def generate(req: GenerateRequest):
    """Generate SQL for a question and, when a real database is known, run it.

    The schema comes from ``req.schema_text`` if given, otherwise from the
    database that ``req.db_id`` resolves to. With a resolved database and
    ``req.self_correct`` set, generation goes through the retry helper, which
    is given a validator that runs each candidate against the database.

    Returns:
        A dict with ``sql``, ``attempts``, ``trace``, ``self_corrected``,
        ``elapsed_s`` (generation time, including time spent waiting for the
        generation lock) and the execution fields ``executed``, ``columns``,
        ``rows``, ``row_count``, ``error``.

    Raises:
        HTTPException: 503 while the model is not online; 400 when the
            question is blank or no schema can be determined.
    """
    if STATE["status"] != "online":
        raise HTTPException(503, f"model not ready ({STATE['status']})")
    if not req.question.strip():
        raise HTTPException(400, "question is required")

    db_path = _db_path(req.db_id) if req.db_id else None
    schema_text = req.schema_text
    if db_path and not schema_text:
        schema_text = schema_from_sqlite(db_path)
    if not schema_text:
        raise HTTPException(400, "provide a schema or pick an example database")

    model, tok = _MODEL["model"], _MODEL["tok"]
    trace: list = []
    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # by wall-clock adjustments.
    started = time.perf_counter()
    with _LOCK:
        if db_path and req.self_correct:
            # NOTE(review): element [1] of run_sql()'s result is presumably the
            # error (None on success) — confirm against sqlforge.exec_eval.
            sql, attempts = generate_sql_with_retry(
                model, tok, schema_text, req.question,
                validate=lambda s: run_sql(db_path, s)[1],
                max_retries=MAX_RETRIES, max_new_tokens=GEN_TOKENS, trace=trace)
        else:
            sql = generate_sql(model, tok, schema_text, req.question,
                               max_new_tokens=GEN_TOKENS)
            attempts = 1
    elapsed = round(time.perf_counter() - started, 2)

    resp = {"sql": sql, "attempts": attempts, "trace": trace,
            "self_corrected": attempts > 1, "elapsed_s": elapsed,
            "executed": False, "columns": None, "rows": None,
            "row_count": None, "error": None}
    # if we have the real DB, run the final query and return the result preview
    if db_path:
        resp.update(_preview_query(db_path, sql))
    return resp
# --- static frontend (mounted last so /api/* wins) ------------------------------
@app.get("/")
def index():
    # Serve the frontend's entry page from the static directory.
    return FileResponse(STATIC.joinpath("index.html"))


# Everything else under "/" is served straight from the static directory.
app.mount(
    "/",
    StaticFiles(directory=str(STATIC)),
    name="static",
)