import logging import os import torch from flask import Flask, request, render_template_string, jsonify from flask_cors import CORS from transformers import AutoTokenizer, AutoModelForSeq2SeqLM from config import MODEL_PATH, HF_MODEL_ID, MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH, NUM_BEAMS, PROMPT_TEMPLATE, MAX_QUESTION_LENGTH, MAX_SCHEMA_LENGTH from schema import truncate_schema logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger(__name__) app = Flask(__name__) CORS(app) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = None model = None def get_model(): global tokenizer, model if model is None: if os.path.exists(MODEL_PATH): source = MODEL_PATH else: log.info(f"Local model not found at '{MODEL_PATH}', downloading from HuggingFace: {HF_MODEL_ID}") source = HF_MODEL_ID tokenizer = AutoTokenizer.from_pretrained(source) model = AutoModelForSeq2SeqLM.from_pretrained(source) model = model.to(device) model.eval() log.info(f"Model loaded from {source} on {device}") return tokenizer, model def predict(question, db_id="unknown", schema="unknown"): schema = truncate_schema(schema, MAX_SCHEMA_LENGTH) input_text = PROMPT_TEMPLATE.format(db_id=db_id, schema=schema, question=question) tokenizer, model = get_model() tokenized_input = tokenizer(input_text, max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt") tokenized_outputs = model.generate( input_ids=tokenized_input["input_ids"].to(device), attention_mask=tokenized_input["attention_mask"].to(device), max_length=MAX_OUTPUT_LENGTH, num_beams=NUM_BEAMS, ) return tokenizer.decode(tokenized_outputs[0], skip_special_tokens=True) HTML = """
Ask a question in plain English. Get a SQL query back.