"""Flask app serving a fine-tuned seq2seq model that turns English questions into SQL.

Routes:
  * ``GET /health``   -- liveness probe, always ``{"status": "ok"}``.
  * ``POST /predict`` -- JSON API: ``{"question": str, "db_id"?: str, "schema"?: str}``
    -> ``{"sql": str}`` or ``{"error": str}``.
  * ``GET/POST /``    -- HTML page; POST reads ``question``/``db_id``/``schema`` form fields.
"""
import logging
import os
import threading

import torch
from flask import Flask, jsonify, render_template_string, request
from flask_cors import CORS
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from config import (
    HF_MODEL_ID,
    MAX_INPUT_LENGTH,
    MAX_OUTPUT_LENGTH,
    MAX_QUESTION_LENGTH,
    MAX_SCHEMA_LENGTH,
    MODEL_PATH,
    NUM_BEAMS,
    PROMPT_TEMPLATE,
)
from schema import truncate_schema

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lazily initialised by get_model(); `model` doubles as the "already loaded" flag.
tokenizer = None
model = None
# Requests may be handled concurrently (Flask's `app.run` is threaded by default),
# so the one-time load is serialised to avoid double loads and half-initialised models.
_model_lock = threading.Lock()


def get_model():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    Loads from the local ``MODEL_PATH`` when that path exists, otherwise from the
    Hugging Face Hub id ``HF_MODEL_ID``. The model is moved to ``device`` and put
    in eval mode *before* it is published to the globals, so other threads never
    observe a model that is still on the wrong device or in training mode.
    """
    global tokenizer, model
    if model is None:
        with _model_lock:
            # Re-check under the lock: another thread may have finished loading meanwhile.
            if model is None:
                if os.path.exists(MODEL_PATH):
                    source = MODEL_PATH
                else:
                    log.info(
                        "Local model not found at '%s', downloading from HuggingFace: %s",
                        MODEL_PATH,
                        HF_MODEL_ID,
                    )
                    source = HF_MODEL_ID
                loaded_tokenizer = AutoTokenizer.from_pretrained(source)
                loaded_model = AutoModelForSeq2SeqLM.from_pretrained(source).to(device)
                loaded_model.eval()
                # Publish the tokenizer first: readers gate on `model is None`.
                tokenizer = loaded_tokenizer
                model = loaded_model
                log.info("Model loaded from %s on %s", source, device)
    return tokenizer, model


def predict(question, db_id="unknown", schema="unknown"):
    """Generate a SQL query for *question*.

    Args:
        question: Natural-language question.
        db_id: Database identifier interpolated into ``PROMPT_TEMPLATE``.
        schema: Schema text; passed through ``truncate_schema(schema, MAX_SCHEMA_LENGTH)``
            (presumably a length cap -- see ``schema.truncate_schema``) before use.

    Returns:
        The top generated sequence decoded to text, with special tokens stripped.
    """
    schema = truncate_schema(schema, MAX_SCHEMA_LENGTH)
    input_text = PROMPT_TEMPLATE.format(db_id=db_id, schema=schema, question=question)
    tok, mdl = get_model()
    # Prompts longer than MAX_INPUT_LENGTH tokens are truncated rather than rejected.
    tokenized_input = tok(
        input_text, max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt"
    )
    tokenized_outputs = mdl.generate(
        input_ids=tokenized_input["input_ids"].to(device),
        attention_mask=tokenized_input["attention_mask"].to(device),
        max_length=MAX_OUTPUT_LENGTH,
        num_beams=NUM_BEAMS,
    )
    return tok.decode(tokenized_outputs[0], skip_special_tokens=True)


# Jinja template for the home page, rendered via render_template_string (which
# autoescapes interpolated values).
# NOTE(review): no <form>/<input> markup is visible here although home() reads
# POSTed `question`, `db_id` and `schema` fields -- confirm the markup is intact.
HTML = """ SQLator — Natural Language to SQL
Fine-tuned CodeT5+ Model

SQLator

Ask a question in plain English. Get a SQL query back.

{% if error %}
{{ error }}
{% endif %} {% if sql %}
Input
{{ question }}
Generated SQL
{{ sql }}
{% endif %}
"""


def _validate_question(question):
    """Return a user-facing error message if *question* is empty or too long, else None."""
    if not question:
        return "Please enter a question."
    if len(question) > MAX_QUESTION_LENGTH:
        return f"Question is too long (max {MAX_QUESTION_LENGTH} characters)."
    return None


def _get_json_text(data, key):
    """Return ``data[key]`` stripped; a missing or null value yields "".

    Raises:
        ValueError: if the value is present but not a string.
    """
    value = data.get(key)
    if value is None:
        return ""
    if not isinstance(value, str):
        raise ValueError(f"'{key}' must be a string.")
    return value.strip()


@app.route("/health", methods=["GET"])
def health():
    """Liveness probe; does not touch the model."""
    return jsonify({"status": "ok"})


@app.route("/predict", methods=["POST"])
def predict_api():
    """JSON endpoint turning a question into SQL.

    Body: ``{"question": str, "db_id"?: str, "schema"?: str}``; ``db_id`` and
    ``schema`` default to ``"unknown"`` when missing or blank.

    Returns:
        ``({"sql": ...}, 200)`` on success, ``({"error": ...}, 400)`` for invalid
        input, ``({"error": ...}, 500)`` if inference raises.
    """
    data = request.get_json(silent=True)
    if data is None:
        # Missing or unparseable JSON: treat as an empty request.
        data = {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400
    try:
        question = _get_json_text(data, "question")
        db_id = _get_json_text(data, "db_id") or "unknown"
        schema = _get_json_text(data, "schema") or "unknown"
    except ValueError as e:
        return jsonify({"error": str(e)}), 400

    error = _validate_question(question)
    if error is not None:
        return jsonify({"error": error}), 400

    try:
        log.info("API predict: question='%s' db_id='%s'", question, db_id)
        sql = predict(question, db_id, schema=schema)
    except Exception as e:
        log.exception("Prediction failed")
        return jsonify({"error": f"Inference failed: {e}"}), 500
    return jsonify({"sql": sql})


@app.route("/", methods=["GET", "POST"])
def home():
    """Render the HTML page; on POST, validate the form and run inference."""
    question = db_id = schema = sql = error = None
    if request.method == "POST":
        question = request.form.get("question", "").strip()
        db_id = request.form.get("db_id", "").strip() or "unknown"
        schema = request.form.get("schema", "").strip() or "unknown"
        error = _validate_question(question)
        if error is None:
            try:
                log.info("Predicting for question='%s' db_id='%s'", question, db_id)
                sql = predict(question, db_id, schema=schema)
            except Exception as e:
                # Mirror predict_api: show the failure on the page instead of a bare 500.
                log.exception("Prediction failed")
                error = f"Inference failed: {e}"
    return render_template_string(
        HTML, question=question, db_id=db_id, schema=schema, sql=sql, error=error
    )


if __name__ == "__main__":
    # Flask debug mode (reloader + interactive debugger) is opt-in via FLASK_DEBUG=true;
    # never enable it on a publicly reachable server.
    debug = os.getenv("FLASK_DEBUG", "false").lower() == "true"
    app.run(debug=debug)