import os

# Point every Hugging Face cache at /tmp — presumably because the default
# home directory is read-only in the deployment container (common on
# hosted platforms).  These MUST be set before the `transformers` import
# below, since the library reads them when it is first imported.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME — confirm the installed version still honors it.
os.environ['HF_HOME'] = '/tmp/hf_home'
os.environ['HF_DATASETS_CACHE'] = '/tmp/hf_datasets_cache'
os.environ['TRANSFORMERS_CACHE'] = '/tmp/cache'
|
|
import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
# Hugging Face Hub model id of the text-to-SQL seq2seq checkpoint.
MODEL_NAME = "16pramodh/t2s_model"

# Load the tokenizer and model ONCE at import time so every request reuses
# them.  The first run downloads the weights into the /tmp caches configured
# at the top of the file; subsequent starts hit the cache.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)


# FastAPI application instance; served by uvicorn in the __main__ guard below.
app = FastAPI()
|
|
class QueryRequest(BaseModel):
    """Request body for ``POST /predict``.

    Carries the natural-language question to be translated into SQL.
    """

    # Part of the public JSON contract ({"text": "..."}) — do not rename.
    text: str
|
|
| @app.get("/") |
| def read_root(): |
| return {"status": "running"} |
|
|
| @app.post("/predict") |
| def predict(request: QueryRequest): |
| inputs = tokenizer(request.text, return_tensors="pt") |
| outputs = model.generate(**inputs, max_length=256) |
| return {"sql": tokenizer.decode(outputs[0], skip_special_tokens=True)} |
|
|
| if __name__ == "__main__": |
| uvicorn.run(app, host="0.0.0.0", port=7860) |