Spaces:
Sleeping
Sleeping
"""
Local text-to-SQL over SQLite databases.

The public product path is:
- load a Qwen2.5-Coder-7B model, optionally with a SQL LoRA adapter
- introspect a SQLite schema
- generate SQL from a natural-language question
- optionally execute the generated SQL, but only if it passes a read-only check
"""
| from __future__ import annotations | |
| import argparse | |
| import re | |
| import sys | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| from peft import PeftModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| from src.bird.inference import build_instruction, extract_sql | |
| from src.shared.schema_loader import get_schema_from_sqlite | |
| from src.shared.sqlite_executor import execute_sqlite_query | |
# Adapter shortcuts accepted by --adapter / resolve_adapter. The "base" entry
# maps to an empty string, which resolve_adapter turns into "no adapter".
ADAPTERS = {
    "spider": "jk200201/qwen2.5-coder-7b-sql-dpo",
    "bird": "jk200201/qwen2.5-coder-7b-bird-dpo",
    "base": "",
}

# Base checkpoint the LoRA adapters are applied on top of.
DEFAULT_BASE = "Qwen/Qwen2.5-Coder-7B-Instruct"
# Default adapter is the Spider-tuned one; derived from ADAPTERS so the repo id
# lives in exactly one place.
DEFAULT_ADAPTER = ADAPTERS["spider"]
def resolve_adapter(adapter: str | None) -> str | None:
    """Translate an ``--adapter`` value into something PeftModel can load.

    * ``None`` selects ``DEFAULT_ADAPTER``.
    * A key of ``ADAPTERS`` ("spider", "bird", "base") is expanded to its value.
    * Any other string (HF repo id or local path) is passed through unchanged.

    An empty result (the empty string, or a shortcut that maps to "") becomes
    ``None``, meaning "run the base model without an adapter".
    """
    if adapter is None:
        return DEFAULT_ADAPTER
    # Unknown values fall back to themselves, so repo ids / paths pass through.
    resolved = ADAPTERS.get(adapter, adapter)
    return resolved if resolved else None
def load_model(
    base_model: str = DEFAULT_BASE,
    adapter: str | None = DEFAULT_ADAPTER,
    use_4bit: bool = True,
):
    """Load *base_model* and, when one resolves, a LoRA adapter on top of it.

    Args:
        base_model: checkpoint passed to ``from_pretrained`` for both the
            tokenizer and the causal-LM model.
        adapter: any value accepted by ``resolve_adapter`` (shortcut, repo id,
            local path, empty string, or ``None``).
        use_4bit: quantize with bitsandbytes NF4 (double quantization, bf16
            compute); when false, load the weights in bf16 instead.

    Returns:
        ``(model, tokenizer)`` with the model already in eval mode.

    Raises:
        RuntimeError: when 4-bit loading is requested but CUDA is unavailable.
    """
    adapter_id = resolve_adapter(adapter)

    if use_4bit and not torch.cuda.is_available():
        raise RuntimeError(
            "4-bit inference needs a CUDA GPU. On Hugging Face Spaces, open "
            "Settings -> Hardware and choose a GPU such as Nvidia L4 before "
            "running the demo."
        )

    # Progress goes to stderr so stdout stays reserved for results.
    label = f" + {adapter_id}" if adapter_id else " (base only)"
    print(f"Loading {base_model}" + label, file=sys.stderr)

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs: dict[str, Any] = {"trust_remote_code": True, "device_map": "auto"}
    if not use_4bit:
        load_kwargs["torch_dtype"] = torch.bfloat16
    else:
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    model = AutoModelForCausalLM.from_pretrained(base_model, **load_kwargs)
    if adapter_id:
        model = PeftModel.from_pretrained(model, adapter_id)
    model.eval()
    return model, tokenizer
def generate_sql(
    model,
    tokenizer,
    schema: str,
    question: str,
    evidence: str = "",
    max_new_tokens: int = 256,
) -> str:
    """Generate SQL for *question* over *schema* with greedy decoding.

    The prompt is the training-time instruction from ``build_instruction``,
    wrapped in the tokenizer's chat template as a single user turn with the
    generation prompt appended.

    Args:
        model: causal LM (optionally PEFT-wrapped) exposing ``.device`` and
            ``.generate``.
        tokenizer: tokenizer matching *model*.
        schema: schema text for the target database.
        question: natural-language question.
        evidence: optional domain hint forwarded to ``build_instruction``.
        max_new_tokens: upper bound on generated tokens.

    Returns:
        The SQL that ``extract_sql`` pulls out of the decoded completion.
    """
    instruction = build_instruction(question, schema, evidence)
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": instruction}],
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(model.device)

    # A single unpadded prompt attends to every one of its tokens. Pass the
    # mask explicitly: when pad_token == eos_token (load_model sets that up if
    # the tokenizer has no pad token), generate() cannot infer the mask itself
    # and warns that results may be unreliable.
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = input_ids.shape[-1]
    raw = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return extract_sql(raw)
# Everything whose contents must not be read as SQL keywords: string literals,
# quoted identifiers and comments. Matching them in ONE left-to-right
# alternation means a "--" inside a string literal is not mistaken for a
# comment, and a quote inside a comment is not mistaken for a string.
_SQL_QUOTED_OR_COMMENT = re.compile(
    r"'(?:[^']|'')*'"         # string literal; '' is an escaped quote
    r'|"(?:[^"]|"")*"'        # double-quoted identifier
    r"|`(?:[^`]|``)*`"        # backtick-quoted identifier
    r"|\[[^\]]*\]"            # bracket-quoted identifier
    r"|--[^\n]*"              # line comment
    r"|/\*.*?(?:\*/|\Z)",     # block comment; SQLite lets it run to end of input
    flags=re.DOTALL,
)
# Leading keywords allowed for a read-only statement.
_READ_ONLY_LEADERS = frozenset({"select", "with"})
# Keywords signalling a write or an environment change. REPLACE is blocked only
# as "REPLACE INTO" so the read-only replace() string function stays allowed.
_MUTATING_SQL = re.compile(
    r"\b(insert|update|delete|drop|alter|create|replace\s+into|attach|detach"
    r"|vacuum|pragma|load_extension)\b",
    flags=re.IGNORECASE,
)


def is_read_only_sql(sql: str) -> bool:
    """
    Conservative guard for product execution.

    Returns True only for a single SELECT/WITH statement that contains none of
    the keywords in ``_MUTATING_SQL``. String literals, quoted identifiers and
    comments are blanked out before any check:
    * keywords inside them neither trigger nor hide a match;
    * a ``;`` inside a literal does not count as a statement separator.

    Any further statement after the first ``;`` is rejected. This keeps the
    CLI/demo from modifying a user's database by accident. It is a keyword
    heuristic, not a SQL parser.
    """
    # Blank with a space (not "") so removal never glues two tokens together.
    cleaned = _SQL_QUOTED_OR_COMMENT.sub(" ", sql).strip()

    # Leading word only, so "SELECT*FROM t" is recognised as a SELECT.
    leader = re.match(r"\w+", cleaned)
    if leader is None or leader.group(0).lower() not in _READ_ONLY_LEADERS:
        return False

    # Trailing semicolons are fine; any other ';' starts a second statement.
    # The keyword scan below cannot vouch for that statement
    # (e.g. REINDEX, ANALYZE).
    body = re.sub(r"[\s;]+\Z", "", cleaned)
    if ";" in body:
        return False

    return _MUTATING_SQL.search(body) is None
def predict(
    db_path: str,
    question: str,
    model,
    tokenizer,
    evidence: str = "",
    execute: bool = True,
    max_new_tokens: int = 256,
) -> dict:
    """Answer *question* against the SQLite database at *db_path*.

    Introspects the schema with ``get_schema_from_sqlite``, generates SQL with
    ``generate_sql``, and, when *execute* is true, runs it through
    ``execute_sqlite_query``. SQL that fails ``is_read_only_sql`` is never run.

    Returns:
        A dict with keys ``sql``, ``columns``, ``rows``, ``row_count``,
        ``error`` and ``schema``. ``error`` holds a refusal message for blocked
        SQL; otherwise it is whatever the executor reported (or ``None`` when
        execution was skipped).
    """
    schema = get_schema_from_sqlite(db_path)
    sql = generate_sql(model, tokenizer, schema, question, evidence, max_new_tokens)
    result = dict(sql=sql, columns=[], rows=[], row_count=0, error=None, schema=schema)

    if not execute:
        return result

    if not is_read_only_sql(sql):
        result["error"] = "Refusing to execute non-read-only SQL. Use --no-exec to inspect it."
        return result

    exec_out = execute_sqlite_query(sql, db_path)
    for field in ("columns", "rows", "row_count", "error"):
        result[field] = exec_out[field]
    return result
def _print_result(result: dict) -> None:
    """Print the generated SQL, then either the error or a preview of at most 50 rows.

    Uses ``tabulate`` (GitHub table format) when it is installed and falls back
    to printing the column list and raw row tuples otherwise.
    """
    print("\nSQL")
    print(result["sql"])

    error = result["error"]
    if error:
        print("\nError")
        print(error)
        return

    total = result["row_count"]
    print(f"\nResults ({total} rows)")
    preview = result["rows"][:50]
    try:
        # Optional dependency: only needed for pretty tables.
        from tabulate import tabulate
        print(tabulate(preview, headers=result["columns"], tablefmt="github"))
    except ImportError:
        print(result["columns"])
        for row in preview:
            print(row)

    if total > 50:
        print(f"... ({total - 50} more rows)")
def main() -> None:
    """Command-line entry point: generate (and by default execute) SQL for one question.

    Exits with status 1 when the result carries an error (refused SQL or an
    execution error), so shell scripts can detect failures.
    """
    parser = argparse.ArgumentParser(description="Local text-to-SQL over a SQLite database")
    parser.add_argument("--db", required=True, help="Path to a .sqlite database file")
    parser.add_argument("--q", "--question", dest="question", required=True)
    parser.add_argument(
        "--adapter",
        default=DEFAULT_ADAPTER,
        help="LoRA adapter, local path, or shortcut: spider, bird, base",
    )
    parser.add_argument("--base-model", default=DEFAULT_BASE)
    parser.add_argument("--evidence", default="", help="Optional BIRD-style domain hint")
    parser.add_argument("--bf16", action="store_true", help="Load in bf16 instead of 4-bit")
    parser.add_argument("--no-exec", action="store_true", help="Generate SQL without running it")
    parser.add_argument("--max-new-tokens", type=int, default=256)
    args = parser.parse_args()

    if args.max_new_tokens <= 0:
        parser.error("--max-new-tokens must be a positive integer")

    # Validate before the slow model load. is_file() rather than exists():
    # a directory passes exists() and would only fail later, when the schema
    # loader tries to open it as a database.
    db_path = Path(args.db)
    if not db_path.is_file():
        sys.exit(f"Database not found: {db_path}")

    model, tokenizer = load_model(
        base_model=args.base_model,
        adapter=args.adapter,
        use_4bit=not args.bf16,
    )
    result = predict(
        str(db_path),
        args.question,
        model,
        tokenizer,
        evidence=args.evidence,
        execute=not args.no_exec,
        max_new_tokens=args.max_new_tokens,
    )
    _print_result(result)
    if result["error"]:
        sys.exit(1)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    raise SystemExit(main())